From 6527ed941684cffb8c411fe83e6ec5d334617d0c Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Tue, 19 Sep 2023 21:50:29 +0000 Subject: [PATCH 1/3] feat: Upgrade to API v1beta3 feat: Add model tuning feat: Add permissions service feat: Add BatchEmbedText and CountTextTokens to the text service PiperOrigin-RevId: 566721485 Source-Link: https://github.com/googleapis/googleapis/commit/87083bffa08ef38a7a1bdab0565a20f9d3b11692 Source-Link: https://github.com/googleapis/googleapis-gen/commit/455ab5756df3a2a9cdae24d061ca832ab1ba5dca Copy-Tag: eyJwIjoicGFja2FnZXMvZ29vZ2xlLWFpLWdlbmVyYXRpdmVsYW5ndWFnZS8uT3dsQm90LnlhbWwiLCJoIjoiNDU1YWI1NzU2ZGYzYTJhOWNkYWUyNGQwNjFjYTgzMmFiMWJhNWRjYSJ9 --- .../v1beta2/.coveragerc | 13 + .../v1beta2/.flake8 | 33 + .../v1beta2/MANIFEST.in | 2 + .../v1beta2/README.rst | 49 + .../v1beta2/docs/_static/custom.css | 3 + .../v1beta2/docs/conf.py | 376 ++ .../discuss_service.rst | 6 + .../model_service.rst | 10 + .../generativelanguage_v1beta2/services.rst | 8 + .../text_service.rst | 6 + .../docs/generativelanguage_v1beta2/types.rst | 6 + .../v1beta2/docs/index.rst | 7 + .../google/ai/generativelanguage/__init__.py | 85 + .../ai/generativelanguage/gapic_version.py | 16 + .../google/ai/generativelanguage/py.typed | 2 + .../ai/generativelanguage_v1beta2/__init__.py | 86 + .../gapic_metadata.json | 156 + .../gapic_version.py | 16 + .../ai/generativelanguage_v1beta2/py.typed | 2 + .../services/__init__.py | 15 + .../services/discuss_service/__init__.py | 22 + .../services/discuss_service/async_client.py | 508 ++ .../services/discuss_service/client.py | 712 +++ .../discuss_service/transports/__init__.py | 38 + .../discuss_service/transports/base.py | 161 + .../discuss_service/transports/grpc.py | 295 + .../transports/grpc_asyncio.py | 294 + .../discuss_service/transports/rest.py | 433 ++ .../services/model_service/__init__.py | 22 + .../services/model_service/async_client.py | 431 ++ .../services/model_service/client.py | 635 +++ .../services/model_service/pagers.py | 140 + .../model_service/transports/__init__.py | 38 + .../services/model_service/transports/base.py | 162 + .../services/model_service/transports/grpc.py | 292 + .../model_service/transports/grpc_asyncio.py | 291 + .../services/model_service/transports/rest.py | 397 ++ .../services/text_service/__init__.py | 22 + .../services/text_service/async_client.py | 514 ++ .../services/text_service/client.py | 718 +++ .../text_service/transports/__init__.py | 38 + .../services/text_service/transports/base.py | 161 + .../services/text_service/transports/grpc.py | 295 + .../text_service/transports/grpc_asyncio.py | 294 + .../services/text_service/transports/rest.py | 423 ++ .../types/__init__.py | 80 + .../types/citation.py | 102 + .../types/discuss_service.py | 358 ++ .../generativelanguage_v1beta2/types/model.py | 156 + .../types/model_service.py | 114 + .../types/safety.py | 247 + .../types/text_service.py | 333 ++ .../v1beta2/mypy.ini | 3 + .../v1beta2/noxfile.py | 184 + ...cuss_service_count_message_tokens_async.py | 56 + ...scuss_service_count_message_tokens_sync.py | 56 + ..._discuss_service_generate_message_async.py | 56 + ...d_discuss_service_generate_message_sync.py | 56 + ...generated_model_service_get_model_async.py | 52 + ..._generated_model_service_get_model_sync.py | 52 + ...nerated_model_service_list_models_async.py | 52 + ...enerated_model_service_list_models_sync.py | 52 + ...generated_text_service_embed_text_async.py | 53 + ..._generated_text_service_embed_text_sync.py | 53 + 
...erated_text_service_generate_text_async.py | 56 + ...nerated_text_service_generate_text_sync.py | 56 + ..._google.ai.generativelanguage.v1beta2.json | 1093 ++++ ...xup_generativelanguage_v1beta2_keywords.py | 181 + .../v1beta2/setup.py | 90 + .../v1beta2/testing/constraints-3.10.txt | 6 + .../v1beta2/testing/constraints-3.11.txt | 6 + .../v1beta2/testing/constraints-3.12.txt | 6 + .../v1beta2/testing/constraints-3.7.txt | 9 + .../v1beta2/testing/constraints-3.8.txt | 6 + .../v1beta2/testing/constraints-3.9.txt | 6 + .../v1beta2/tests/__init__.py | 16 + .../v1beta2/tests/unit/__init__.py | 16 + .../v1beta2/tests/unit/gapic/__init__.py | 16 + .../generativelanguage_v1beta2/__init__.py | 16 + .../test_discuss_service.py | 2205 ++++++++ .../test_model_service.py | 2319 ++++++++ .../test_text_service.py | 2214 ++++++++ .../v1beta3/.coveragerc | 13 + .../v1beta3/.flake8 | 33 + .../v1beta3/MANIFEST.in | 2 + .../v1beta3/README.rst | 49 + .../v1beta3/docs/_static/custom.css | 3 + .../v1beta3/docs/conf.py | 376 ++ .../discuss_service.rst | 6 + .../model_service.rst | 10 + .../permission_service.rst | 10 + .../generativelanguage_v1beta3/services.rst | 9 + .../text_service.rst | 6 + .../docs/generativelanguage_v1beta3/types.rst | 6 + .../v1beta3/docs/index.rst | 7 + .../google/ai/generativelanguage/__init__.py | 145 + .../ai/generativelanguage/gapic_version.py | 16 + .../google/ai/generativelanguage/py.typed | 2 + .../ai/generativelanguage_v1beta3/__init__.py | 146 + .../gapic_metadata.json | 370 ++ .../gapic_version.py | 16 + .../ai/generativelanguage_v1beta3/py.typed | 2 + .../services/__init__.py | 15 + .../services/discuss_service/__init__.py | 22 + .../services/discuss_service/async_client.py | 509 ++ .../services/discuss_service/client.py | 717 +++ .../discuss_service/transports/__init__.py | 38 + .../discuss_service/transports/base.py | 162 + .../discuss_service/transports/grpc.py | 296 + .../transports/grpc_asyncio.py | 295 + .../discuss_service/transports/rest.py | 434 ++ .../services/model_service/__init__.py | 22 + .../services/model_service/async_client.py | 996 ++++ .../services/model_service/client.py | 1213 ++++ .../services/model_service/pagers.py | 262 + .../model_service/transports/__init__.py | 38 + .../services/model_service/transports/base.py | 242 + .../services/model_service/transports/grpc.py | 449 ++ .../model_service/transports/grpc_asyncio.py | 448 ++ .../services/model_service/transports/rest.py | 972 ++++ .../services/permission_service/__init__.py | 22 + .../permission_service/async_client.py | 876 +++ .../services/permission_service/client.py | 1094 ++++ .../services/permission_service/pagers.py | 140 + .../permission_service/transports/__init__.py | 38 + .../permission_service/transports/base.py | 221 + .../permission_service/transports/grpc.py | 402 ++ .../transports/grpc_asyncio.py | 401 ++ .../permission_service/transports/rest.py | 919 ++++ .../services/text_service/__init__.py | 22 + .../services/text_service/async_client.py | 760 +++ .../services/text_service/client.py | 968 ++++ .../text_service/transports/__init__.py | 38 + .../services/text_service/transports/base.py | 190 + .../services/text_service/transports/grpc.py | 350 ++ .../text_service/transports/grpc_asyncio.py | 349 ++ .../services/text_service/transports/rest.py | 676 +++ .../types/__init__.py | 142 + .../types/citation.py | 102 + .../types/discuss_service.py | 358 ++ .../generativelanguage_v1beta3/types/model.py | 156 + .../types/model_service.py | 310 ++ .../types/permission.py | 140 + 
.../types/permission_service.py | 220 + .../types/safety.py | 250 + .../types/text_service.py | 431 ++ .../types/tuned_model.py | 414 ++ .../v1beta3/mypy.ini | 3 + .../v1beta3/noxfile.py | 184 + ...cuss_service_count_message_tokens_async.py | 56 + ...scuss_service_count_message_tokens_sync.py | 56 + ..._discuss_service_generate_message_async.py | 56 + ...d_discuss_service_generate_message_sync.py | 56 + ..._model_service_create_tuned_model_async.py | 60 + ...d_model_service_create_tuned_model_sync.py | 60 + ..._model_service_delete_tuned_model_async.py | 50 + ...d_model_service_delete_tuned_model_sync.py | 50 + ...generated_model_service_get_model_async.py | 52 + ..._generated_model_service_get_model_sync.py | 52 + ...ted_model_service_get_tuned_model_async.py | 52 + ...ated_model_service_get_tuned_model_sync.py | 52 + ...nerated_model_service_list_models_async.py | 52 + ...enerated_model_service_list_models_sync.py | 52 + ...d_model_service_list_tuned_models_async.py | 52 + ...ed_model_service_list_tuned_models_sync.py | 52 + ..._model_service_update_tuned_model_async.py | 56 + ...d_model_service_update_tuned_model_sync.py | 56 + ...mission_service_create_permission_async.py | 52 + ...rmission_service_create_permission_sync.py | 52 + ...mission_service_delete_permission_async.py | 50 + ...rmission_service_delete_permission_sync.py | 50 + ...permission_service_get_permission_async.py | 52 + ..._permission_service_get_permission_sync.py | 52 + ...rmission_service_list_permissions_async.py | 53 + ...ermission_service_list_permissions_sync.py | 53 + ...ission_service_transfer_ownership_async.py | 53 + ...mission_service_transfer_ownership_sync.py | 53 + ...mission_service_update_permission_async.py | 51 + ...rmission_service_update_permission_sync.py | 51 + ...ted_text_service_batch_embed_text_async.py | 53 + ...ated_text_service_batch_embed_text_sync.py | 53 + ...ed_text_service_count_text_tokens_async.py | 56 + ...ted_text_service_count_text_tokens_sync.py | 56 + ...generated_text_service_embed_text_async.py | 53 + ..._generated_text_service_embed_text_sync.py | 53 + ...erated_text_service_generate_text_async.py | 56 + ...nerated_text_service_generate_text_sync.py | 56 + ..._google.ai.generativelanguage.v1beta3.json | 3222 +++++++++++ ...xup_generativelanguage_v1beta3_keywords.py | 194 + .../v1beta3/setup.py | 90 + .../v1beta3/testing/constraints-3.10.txt | 6 + .../v1beta3/testing/constraints-3.11.txt | 6 + .../v1beta3/testing/constraints-3.12.txt | 6 + .../v1beta3/testing/constraints-3.7.txt | 9 + .../v1beta3/testing/constraints-3.8.txt | 6 + .../v1beta3/testing/constraints-3.9.txt | 6 + .../v1beta3/tests/__init__.py | 16 + .../v1beta3/tests/unit/__init__.py | 16 + .../v1beta3/tests/unit/gapic/__init__.py | 16 + .../generativelanguage_v1beta3/__init__.py | 16 + .../test_discuss_service.py | 2206 ++++++++ .../test_model_service.py | 4869 +++++++++++++++++ .../test_permission_service.py | 4220 ++++++++++++++ .../test_text_service.py | 3177 +++++++++++ 204 files changed, 57309 insertions(+) create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/.coveragerc create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/.flake8 create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/MANIFEST.in create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/README.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/_static/custom.css create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/conf.py 
create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/discuss_service.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/model_service.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/services.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/text_service.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/types.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/index.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/gapic_version.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/py.typed create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_metadata.json create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_version.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/py.typed create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/async_client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/base.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc_asyncio.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/rest.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/async_client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/pagers.py create mode 100644 
owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/base.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc_asyncio.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/rest.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/async_client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/base.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc_asyncio.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/rest.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/citation.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/discuss_service.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model_service.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/safety.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/text_service.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/mypy.ini create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/noxfile.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_async.py create mode 
100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta2.json create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/scripts/fixup_generativelanguage_v1beta2_keywords.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/setup.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.10.txt create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.11.txt create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.12.txt create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.7.txt create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.8.txt create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.9.txt create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_discuss_service.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_model_service.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_text_service.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/.coveragerc create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/.flake8 create mode 100644 
owl-bot-staging/google-ai-generativelanguage/v1beta3/MANIFEST.in create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/README.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/_static/custom.css create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/conf.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/discuss_service.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/model_service.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/permission_service.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/services.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/text_service.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/types.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/index.rst create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/gapic_version.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/py.typed create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_metadata.json create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_version.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/py.typed create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/async_client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/base.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc_asyncio.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/rest.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/__init__.py create mode 100644 
owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/async_client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/pagers.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/base.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc_asyncio.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/rest.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/pagers.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/base.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc_asyncio.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/rest.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/async_client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/client.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/base.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc.py create mode 100644 
owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc_asyncio.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/rest.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/citation.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/discuss_service.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model_service.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission_service.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/safety.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/text_service.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/tuned_model.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/mypy.ini create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/noxfile.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_sync.py create mode 100644 
owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_sync.py create mode 100644 
owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_async.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_sync.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta3.json create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/scripts/fixup_generativelanguage_v1beta3_keywords.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/setup.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.10.txt create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.11.txt create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.12.txt create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.7.txt create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.8.txt create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.9.txt create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/__init__.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_discuss_service.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_model_service.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_permission_service.py create mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_text_service.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/.coveragerc b/owl-bot-staging/google-ai-generativelanguage/v1beta2/.coveragerc new file mode 100644 index 000000000000..fd060ae956b5 --- /dev/null +++ 
b/owl-bot-staging/google-ai-generativelanguage/v1beta2/.coveragerc @@ -0,0 +1,13 @@ +[run] +branch = True + +[report] +show_missing = True +omit = + google/ai/generativelanguage/__init__.py + google/ai/generativelanguage/gapic_version.py +exclude_lines = + # Re-enable the standard pragma + pragma: NO COVER + # Ignore debug-only repr + def __repr__ diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/.flake8 b/owl-bot-staging/google-ai-generativelanguage/v1beta2/.flake8 new file mode 100644 index 000000000000..29227d4cf419 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/.flake8 @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by synthtool. DO NOT EDIT! +[flake8] +ignore = E203, E266, E501, W503 +exclude = + # Exclude generated code. + **/proto/** + **/gapic/** + **/services/** + **/types/** + *_pb2.py + + # Standard linting exemptions. + **/.nox/** + __pycache__, + .git, + *.pyc, + conf.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/MANIFEST.in b/owl-bot-staging/google-ai-generativelanguage/v1beta2/MANIFEST.in new file mode 100644 index 000000000000..27e3433a8451 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/MANIFEST.in @@ -0,0 +1,2 @@ +recursive-include google/ai/generativelanguage *.py +recursive-include google/ai/generativelanguage_v1beta2 *.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/README.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/README.rst new file mode 100644 index 000000000000..099f73894711 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/README.rst @@ -0,0 +1,49 @@ +Python Client for Google Ai Generativelanguage API +================================================== + +Quick Start +----------- + +In order to use this library, you first need to go through the following steps: + +1. `Select or create a Cloud Platform project.`_ +2. `Enable billing for your project.`_ +3. Enable the Google Ai Generativelanguage API. +4. `Setup Authentication.`_ + +.. _Select or create a Cloud Platform project.: https://console.cloud.google.com/project +.. _Enable billing for your project.: https://cloud.google.com/billing/docs/how-to/modify-project#enable_billing_for_a_project +.. _Setup Authentication.: https://googleapis.dev/python/google-api-core/latest/auth.html + +Installation +~~~~~~~~~~~~ + +Install this library in a `virtualenv`_ using pip. `virtualenv`_ is a tool to +create isolated Python environments. The basic problem it addresses is one of +dependencies and versions, and indirectly permissions. + +With `virtualenv`_, it's possible to install this library without needing system +install permissions, and without clashing with the installed system +dependencies. + +.. _`virtualenv`: https://virtualenv.pypa.io/en/latest/ + + +Mac/Linux +^^^^^^^^^ + +..
code-block:: console + + python3 -m venv <your-env> + source <your-env>/bin/activate + <your-env>/bin/pip install /path/to/library + + +Windows +^^^^^^^ + +.. code-block:: console + + python3 -m venv <your-env> + <your-env>\Scripts\activate + <your-env>\Scripts\pip.exe install \path\to\library diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/_static/custom.css b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/_static/custom.css new file mode 100644 index 000000000000..06423be0b592 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/_static/custom.css @@ -0,0 +1,3 @@ +dl.field-list > dt { + min-width: 100px +} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/conf.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/conf.py new file mode 100644 index 000000000000..0f3f4903ff54 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/conf.py @@ -0,0 +1,376 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +# google-ai-generativelanguage documentation build configuration file +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import sys +import os +import shlex + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +sys.path.insert(0, os.path.abspath("..")) + +__version__ = "0.1.0" + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +needs_sphinx = "4.0.1" + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.coverage", + "sphinx.ext.napoleon", + "sphinx.ext.todo", + "sphinx.ext.viewcode", +] + +# autodoc/autosummary flags +autoclass_content = "both" +autodoc_default_flags = ["members"] +autosummary_generate = True + + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# Allow markdown includes (so releases.md can include CHANGELOG.md) +# http://www.sphinx-doc.org/en/master/markdown.html +source_parsers = {".md": "recommonmark.parser.CommonMarkParser"} + +# The suffix(es) of source filenames. +# You can specify multiple suffixes as a list of strings: +source_suffix = [".rst", ".md"] + +# The encoding of source files. +# source_encoding = 'utf-8-sig' + +# The root toctree document. +root_doc = "index" + +# General information about the project.
+project = u"google-ai-generativelanguage" +copyright = u"2023, Google, LLC" +author = u"Google APIs" # TODO: autogenerate this bit + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The full version, including alpha/beta/rc tags. +release = __version__ +# The short X.Y version. +version = ".".join(release.split(".")[0:2]) + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = 'en' + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +# today = '' +# Else, today_fmt is used as the format for a strftime call. +# today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = ["_build"] + +# The reST default role (used for this markup: `text`) to use for all +# documents. +# default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +# add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +# add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +# show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + +# A list of ignored prefixes for module index sorting. +# modindex_common_prefix = [] + +# If true, keep warnings as "system message" paragraphs in the built documents. +# keep_warnings = False + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = True + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = "alabaster" + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +html_theme_options = { + "description": "Google Ai Client Libraries for Python", + "github_user": "googleapis", + "github_repo": "google-cloud-python", + "github_banner": True, + "font_family": "'Roboto', Georgia, sans", + "head_font_family": "'Roboto', Georgia, serif", + "code_font_family": "'Roboto Mono', 'Consolas', monospace", +} + +# Add any paths that contain custom themes here, relative to this directory. +# html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# "<project> v<release> documentation". +# html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +# html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +# html_logo = None + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +# html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory.
They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + +# Add any extra paths that contain custom files (such as robots.txt or +# .htaccess) here, relative to this directory. These files are copied +# directly to the root of the documentation. +# html_extra_path = [] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +# html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +# html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +# html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +# html_additional_pages = {} + +# If false, no module index is generated. +# html_domain_indices = True + +# If false, no index is generated. +# html_use_index = True + +# If true, the index is split into individual pages for each letter. +# html_split_index = False + +# If true, links to the reST sources are added to the pages. +# html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +# html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +# html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a <link> tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +# html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. ".xhtml"). +# html_file_suffix = None + +# Language to be used for generating the HTML full-text search index. +# Sphinx supports the following languages: +# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' +# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' +# html_search_language = 'en' + +# A dictionary with options for the search language support, empty by default. +# Now only 'ja' uses this config value +# html_search_options = {'type': 'default'} + +# The name of a javascript file (relative to the configuration directory) that +# implements a search results scorer. If empty, the default will be used. +# html_search_scorer = 'scorer.js' + +# Output file base name for HTML help builder. +htmlhelp_basename = "google-ai-generativelanguage-doc" + +# -- Options for warnings ------------------------------------------------------ + + +suppress_warnings = [ + # Temporarily suppress this to avoid "more than one target found for + # cross-reference" warnings, which are intractable for us to avoid while in + # a mono-repo. + # See https://github.com/sphinx-doc/sphinx/blob + # /2a65ffeef5c107c19084fabdd706cdff3f52d93c/sphinx/domains/python.py#L843 + "ref.python" +] + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # 'preamble': '', + # Latex figure (float) alignment + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]).
+latex_documents = [ + ( + root_doc, + "google-ai-generativelanguage.tex", + u"google-ai-generativelanguage Documentation", + author, + "manual", + ) +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +# latex_logo = None + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +# latex_use_parts = False + +# If true, show page references after internal links. +# latex_show_pagerefs = False + +# If true, show URL addresses after external links. +# latex_show_urls = False + +# Documents to append as an appendix to all manuals. +# latex_appendices = [] + +# If false, no module index is generated. +# latex_domain_indices = True + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + ( + root_doc, + "google-ai-generativelanguage", + u"Google Ai Generativelanguage Documentation", + [author], + 1, + ) +] + +# If true, show URL addresses after external links. +# man_show_urls = False + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ( + root_doc, + "google-ai-generativelanguage", + u"google-ai-generativelanguage Documentation", + author, + "google-ai-generativelanguage", + "GAPIC library for Google Ai Generativelanguage API", + "APIs", + ) +] + +# Documents to append as an appendix to all manuals. +# texinfo_appendices = [] + +# If false, no module index is generated. +# texinfo_domain_indices = True + +# How to display URL addresses: 'footnote', 'no', or 'inline'. +# texinfo_show_urls = 'footnote' + +# If true, do not generate a @detailmenu in the "Top" node's menu. +# texinfo_no_detailmenu = False + + +# Example configuration for intersphinx: refer to the Python standard library. 
+intersphinx_mapping = { + "python": ("http://python.readthedocs.org/en/latest/", None), + "gax": ("https://gax-python.readthedocs.org/en/latest/", None), + "google-auth": ("https://google-auth.readthedocs.io/en/stable", None), + "google-gax": ("https://gax-python.readthedocs.io/en/latest/", None), + "google.api_core": ("https://googleapis.dev/python/google-api-core/latest/", None), + "grpc": ("https://grpc.io/grpc/python/", None), + "requests": ("http://requests.kennethreitz.org/en/stable/", None), + "proto": ("https://proto-plus-python.readthedocs.io/en/stable", None), + "protobuf": ("https://googleapis.dev/python/protobuf/latest/", None), +} + + +# Napoleon settings +napoleon_google_docstring = True +napoleon_numpy_docstring = True +napoleon_include_private_with_doc = False +napoleon_include_special_with_doc = True +napoleon_use_admonition_for_examples = False +napoleon_use_admonition_for_notes = False +napoleon_use_admonition_for_references = False +napoleon_use_ivar = False +napoleon_use_param = True +napoleon_use_rtype = True diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/discuss_service.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/discuss_service.rst new file mode 100644 index 000000000000..be72af9f8e59 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/discuss_service.rst @@ -0,0 +1,6 @@ +DiscussService +-------------------------------- + +.. automodule:: google.ai.generativelanguage_v1beta2.services.discuss_service + :members: + :inherited-members: diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/model_service.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/model_service.rst new file mode 100644 index 000000000000..7edf8f7f17c5 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/model_service.rst @@ -0,0 +1,10 @@ +ModelService +------------------------------ + +.. automodule:: google.ai.generativelanguage_v1beta2.services.model_service + :members: + :inherited-members: + +.. automodule:: google.ai.generativelanguage_v1beta2.services.model_service.pagers + :members: + :inherited-members: diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/services.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/services.rst new file mode 100644 index 000000000000..e9e01c10ac08 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/services.rst @@ -0,0 +1,8 @@ +Services for Google Ai Generativelanguage v1beta2 API +===================================================== +.. toctree:: + :maxdepth: 2 + + discuss_service + model_service + text_service diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/text_service.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/text_service.rst new file mode 100644 index 000000000000..f30551e0f177 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/text_service.rst @@ -0,0 +1,6 @@ +TextService +----------------------------- + +.. 
automodule:: google.ai.generativelanguage_v1beta2.services.text_service + :members: + :inherited-members: diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/types.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/types.rst new file mode 100644 index 000000000000..81b702c1c9e1 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/types.rst @@ -0,0 +1,6 @@ +Types for Google Ai Generativelanguage v1beta2 API +================================================== + +.. automodule:: google.ai.generativelanguage_v1beta2.types + :members: + :show-inheritance: diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/index.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/index.rst new file mode 100644 index 000000000000..c5b70436ea18 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/index.rst @@ -0,0 +1,7 @@ +API Reference +------------- +.. toctree:: + :maxdepth: 2 + + generativelanguage_v1beta2/services + generativelanguage_v1beta2/types diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/__init__.py new file mode 100644 index 000000000000..16becd33efb7 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/__init__.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from google.ai.generativelanguage import gapic_version as package_version + +__version__ = package_version.__version__ + + +from google.ai.generativelanguage_v1beta2.services.discuss_service.client import DiscussServiceClient +from google.ai.generativelanguage_v1beta2.services.discuss_service.async_client import DiscussServiceAsyncClient +from google.ai.generativelanguage_v1beta2.services.model_service.client import ModelServiceClient +from google.ai.generativelanguage_v1beta2.services.model_service.async_client import ModelServiceAsyncClient +from google.ai.generativelanguage_v1beta2.services.text_service.client import TextServiceClient +from google.ai.generativelanguage_v1beta2.services.text_service.async_client import TextServiceAsyncClient + +from google.ai.generativelanguage_v1beta2.types.citation import CitationMetadata +from google.ai.generativelanguage_v1beta2.types.citation import CitationSource +from google.ai.generativelanguage_v1beta2.types.discuss_service import CountMessageTokensRequest +from google.ai.generativelanguage_v1beta2.types.discuss_service import CountMessageTokensResponse +from google.ai.generativelanguage_v1beta2.types.discuss_service import Example +from google.ai.generativelanguage_v1beta2.types.discuss_service import GenerateMessageRequest +from google.ai.generativelanguage_v1beta2.types.discuss_service import GenerateMessageResponse +from google.ai.generativelanguage_v1beta2.types.discuss_service import Message +from google.ai.generativelanguage_v1beta2.types.discuss_service import MessagePrompt +from google.ai.generativelanguage_v1beta2.types.model import Model +from google.ai.generativelanguage_v1beta2.types.model_service import GetModelRequest +from google.ai.generativelanguage_v1beta2.types.model_service import ListModelsRequest +from google.ai.generativelanguage_v1beta2.types.model_service import ListModelsResponse +from google.ai.generativelanguage_v1beta2.types.safety import ContentFilter +from google.ai.generativelanguage_v1beta2.types.safety import SafetyFeedback +from google.ai.generativelanguage_v1beta2.types.safety import SafetyRating +from google.ai.generativelanguage_v1beta2.types.safety import SafetySetting +from google.ai.generativelanguage_v1beta2.types.safety import HarmCategory +from google.ai.generativelanguage_v1beta2.types.text_service import Embedding +from google.ai.generativelanguage_v1beta2.types.text_service import EmbedTextRequest +from google.ai.generativelanguage_v1beta2.types.text_service import EmbedTextResponse +from google.ai.generativelanguage_v1beta2.types.text_service import GenerateTextRequest +from google.ai.generativelanguage_v1beta2.types.text_service import GenerateTextResponse +from google.ai.generativelanguage_v1beta2.types.text_service import TextCompletion +from google.ai.generativelanguage_v1beta2.types.text_service import TextPrompt + +__all__ = ('DiscussServiceClient', + 'DiscussServiceAsyncClient', + 'ModelServiceClient', + 'ModelServiceAsyncClient', + 'TextServiceClient', + 'TextServiceAsyncClient', + 'CitationMetadata', + 'CitationSource', + 'CountMessageTokensRequest', + 'CountMessageTokensResponse', + 'Example', + 'GenerateMessageRequest', + 'GenerateMessageResponse', + 'Message', + 'MessagePrompt', + 'Model', + 'GetModelRequest', + 'ListModelsRequest', + 'ListModelsResponse', + 'ContentFilter', + 'SafetyFeedback', + 'SafetyRating', + 'SafetySetting', + 'HarmCategory', + 'Embedding', + 'EmbedTextRequest', + 'EmbedTextResponse', + 'GenerateTextRequest', + 'GenerateTextResponse', + 'TextCompletion', + 
'TextPrompt', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/gapic_version.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/gapic_version.py new file mode 100644 index 000000000000..360a0d13ebdd --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/gapic_version.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +__version__ = "0.0.0" # {x-release-please-version} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/py.typed b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/py.typed new file mode 100644 index 000000000000..38773eee6363 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-ai-generativelanguage package uses inline types. diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/__init__.py new file mode 100644 index 000000000000..06c40d65931f --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/__init__.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from google.ai.generativelanguage_v1beta2 import gapic_version as package_version + +__version__ = package_version.__version__ + + +from .services.discuss_service import DiscussServiceClient +from .services.discuss_service import DiscussServiceAsyncClient +from .services.model_service import ModelServiceClient +from .services.model_service import ModelServiceAsyncClient +from .services.text_service import TextServiceClient +from .services.text_service import TextServiceAsyncClient + +from .types.citation import CitationMetadata +from .types.citation import CitationSource +from .types.discuss_service import CountMessageTokensRequest +from .types.discuss_service import CountMessageTokensResponse +from .types.discuss_service import Example +from .types.discuss_service import GenerateMessageRequest +from .types.discuss_service import GenerateMessageResponse +from .types.discuss_service import Message +from .types.discuss_service import MessagePrompt +from .types.model import Model +from .types.model_service import GetModelRequest +from .types.model_service import ListModelsRequest +from .types.model_service import ListModelsResponse +from .types.safety import ContentFilter +from .types.safety import SafetyFeedback +from .types.safety import SafetyRating +from .types.safety import SafetySetting +from .types.safety import HarmCategory +from .types.text_service import Embedding +from .types.text_service import EmbedTextRequest +from .types.text_service import EmbedTextResponse +from .types.text_service import GenerateTextRequest +from .types.text_service import GenerateTextResponse +from .types.text_service import TextCompletion +from .types.text_service import TextPrompt + +__all__ = ( + 'DiscussServiceAsyncClient', + 'ModelServiceAsyncClient', + 'TextServiceAsyncClient', +'CitationMetadata', +'CitationSource', +'ContentFilter', +'CountMessageTokensRequest', +'CountMessageTokensResponse', +'DiscussServiceClient', +'EmbedTextRequest', +'EmbedTextResponse', +'Embedding', +'Example', +'GenerateMessageRequest', +'GenerateMessageResponse', +'GenerateTextRequest', +'GenerateTextResponse', +'GetModelRequest', +'HarmCategory', +'ListModelsRequest', +'ListModelsResponse', +'Message', +'MessagePrompt', +'Model', +'ModelServiceClient', +'SafetyFeedback', +'SafetyRating', +'SafetySetting', +'TextCompletion', +'TextPrompt', +'TextServiceClient', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_metadata.json b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_metadata.json new file mode 100644 index 000000000000..e4a6a33d7d90 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_metadata.json @@ -0,0 +1,156 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.ai.generativelanguage_v1beta2", + "protoPackage": "google.ai.generativelanguage.v1beta2", + "schema": "1.0", + "services": { + "DiscussService": { + "clients": { + "grpc": { + "libraryClient": "DiscussServiceClient", + "rpcs": { + "CountMessageTokens": { + "methods": [ + "count_message_tokens" + ] + }, + "GenerateMessage": { + "methods": [ + "generate_message" + ] + } + } + }, + "grpc-async": { + "libraryClient": "DiscussServiceAsyncClient", + "rpcs": { + "CountMessageTokens": { + "methods": [ + "count_message_tokens" + ] + }, + "GenerateMessage": { + "methods": [ + 
"generate_message" + ] + } + } + }, + "rest": { + "libraryClient": "DiscussServiceClient", + "rpcs": { + "CountMessageTokens": { + "methods": [ + "count_message_tokens" + ] + }, + "GenerateMessage": { + "methods": [ + "generate_message" + ] + } + } + } + } + }, + "ModelService": { + "clients": { + "grpc": { + "libraryClient": "ModelServiceClient", + "rpcs": { + "GetModel": { + "methods": [ + "get_model" + ] + }, + "ListModels": { + "methods": [ + "list_models" + ] + } + } + }, + "grpc-async": { + "libraryClient": "ModelServiceAsyncClient", + "rpcs": { + "GetModel": { + "methods": [ + "get_model" + ] + }, + "ListModels": { + "methods": [ + "list_models" + ] + } + } + }, + "rest": { + "libraryClient": "ModelServiceClient", + "rpcs": { + "GetModel": { + "methods": [ + "get_model" + ] + }, + "ListModels": { + "methods": [ + "list_models" + ] + } + } + } + } + }, + "TextService": { + "clients": { + "grpc": { + "libraryClient": "TextServiceClient", + "rpcs": { + "EmbedText": { + "methods": [ + "embed_text" + ] + }, + "GenerateText": { + "methods": [ + "generate_text" + ] + } + } + }, + "grpc-async": { + "libraryClient": "TextServiceAsyncClient", + "rpcs": { + "EmbedText": { + "methods": [ + "embed_text" + ] + }, + "GenerateText": { + "methods": [ + "generate_text" + ] + } + } + }, + "rest": { + "libraryClient": "TextServiceClient", + "rpcs": { + "EmbedText": { + "methods": [ + "embed_text" + ] + }, + "GenerateText": { + "methods": [ + "generate_text" + ] + } + } + } + } + } + } +} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_version.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_version.py new file mode 100644 index 000000000000..360a0d13ebdd --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_version.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +__version__ = "0.0.0" # {x-release-please-version} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/py.typed b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/py.typed new file mode 100644 index 000000000000..38773eee6363 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-ai-generativelanguage package uses inline types. 
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/__init__.py new file mode 100644 index 000000000000..89a37dc92c5a --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/__init__.py new file mode 100644 index 000000000000..c5c6e8208269 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .client import DiscussServiceClient +from .async_client import DiscussServiceAsyncClient + +__all__ = ( + 'DiscussServiceClient', + 'DiscussServiceAsyncClient', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/async_client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/async_client.py new file mode 100644 index 000000000000..b6fbb11900d2 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/async_client.py @@ -0,0 +1,508 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from collections import OrderedDict +import functools +import re +from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union + +from google.ai.generativelanguage_v1beta2 import gapic_version as package_version + +from google.api_core.client_options import ClientOptions +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.ai.generativelanguage_v1beta2.types import discuss_service +from google.ai.generativelanguage_v1beta2.types import safety +from .transports.base import DiscussServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import DiscussServiceGrpcAsyncIOTransport +from .client import DiscussServiceClient + + +class DiscussServiceAsyncClient: + """An API for using Generative Language Models (GLMs) in dialog + applications. + Also known as large language models (LLMs), this API provides + models that are trained for multi-turn dialog. + """ + + _client: DiscussServiceClient + + DEFAULT_ENDPOINT = DiscussServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = DiscussServiceClient.DEFAULT_MTLS_ENDPOINT + + model_path = staticmethod(DiscussServiceClient.model_path) + parse_model_path = staticmethod(DiscussServiceClient.parse_model_path) + common_billing_account_path = staticmethod(DiscussServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(DiscussServiceClient.parse_common_billing_account_path) + common_folder_path = staticmethod(DiscussServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(DiscussServiceClient.parse_common_folder_path) + common_organization_path = staticmethod(DiscussServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(DiscussServiceClient.parse_common_organization_path) + common_project_path = staticmethod(DiscussServiceClient.common_project_path) + parse_common_project_path = staticmethod(DiscussServiceClient.parse_common_project_path) + common_location_path = staticmethod(DiscussServiceClient.common_location_path) + parse_common_location_path = staticmethod(DiscussServiceClient.parse_common_location_path) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DiscussServiceAsyncClient: The constructed client. + """ + return DiscussServiceClient.from_service_account_info.__func__(DiscussServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DiscussServiceAsyncClient: The constructed client. 
+ """ + return DiscussServiceClient.from_service_account_file.__func__(DiscussServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return DiscussServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> DiscussServiceTransport: + """Returns the transport used by the client instance. + + Returns: + DiscussServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(DiscussServiceClient).get_transport_class, type(DiscussServiceClient)) + + def __init__(self, *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Union[str, DiscussServiceTransport] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the discuss service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.DiscussServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. 
+ (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client = DiscussServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def generate_message(self, + request: Optional[Union[discuss_service.GenerateMessageRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> discuss_service.GenerateMessageResponse: + r"""Generates a response from the model given an input + ``MessagePrompt``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta2 + + async def sample_generate_message(): + # Create a client + client = generativelanguage_v1beta2.DiscussServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta2.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1beta2.GenerateMessageRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.generate_message(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta2.types.GenerateMessageRequest, dict]]): + The request object. Request to generate a message + response from the model. + model (:class:`str`): + Required. The name of the model to use. + + Format: ``name=models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (:class:`google.ai.generativelanguage_v1beta2.types.MessagePrompt`): + Required. The structured textual + input given to the model as a prompt. + Given a + prompt, the model will return what it + predicts is the next message in the + discussion. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + temperature (:class:`float`): + Optional. Controls the randomness of the output. + + Values can range over ``[0.0,1.0]``, inclusive. A value + closer to ``1.0`` will produce responses that are more + varied, while a value closer to ``0.0`` will typically + result in less surprising responses from the model. + + This corresponds to the ``temperature`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + candidate_count (:class:`int`): + Optional. 
The number of generated response messages to + return. + + This value must be between ``[1, 8]``, inclusive. If + unset, this will default to ``1``. + + This corresponds to the ``candidate_count`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_p (:class:`float`): + Optional. The maximum cumulative probability of tokens + to consider when sampling. + + The model uses combined Top-k and nucleus sampling. + + Nucleus sampling considers the smallest set of tokens + whose probability sum is at least ``top_p``. + + This corresponds to the ``top_p`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_k (:class:`int`): + Optional. The maximum number of tokens to consider when + sampling. + + The model uses combined Top-k and nucleus sampling. + + Top-k sampling considers the set of ``top_k`` most + probable tokens. + + This corresponds to the ``top_k`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta2.types.GenerateMessageResponse: + The response from the model. + + This includes candidate messages and + conversation history in the form of + chronologically-ordered messages. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt, temperature, candidate_count, top_p, top_k]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = discuss_service.GenerateMessageRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + if temperature is not None: + request.temperature = temperature + if candidate_count is not None: + request.candidate_count = candidate_count + if top_p is not None: + request.top_p = top_p + if top_k is not None: + request.top_k = top_k + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.generate_message, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + async def count_message_tokens(self, + request: Optional[Union[discuss_service.CountMessageTokensRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> discuss_service.CountMessageTokensResponse: + r"""Runs a model's tokenizer on a string and returns the + token count. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta2 + + async def sample_count_message_tokens(): + # Create a client + client = generativelanguage_v1beta2.DiscussServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta2.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1beta2.CountMessageTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.count_message_tokens(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta2.types.CountMessageTokensRequest, dict]]): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + model (:class:`str`): + Required. The model's resource name. This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (:class:`google.ai.generativelanguage_v1beta2.types.MessagePrompt`): + Required. The prompt, whose token + count is to be returned. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta2.types.CountMessageTokensResponse: + A response from CountMessageTokens. + + It returns the model's token_count for the prompt. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = discuss_service.CountMessageTokensRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
+ if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.count_message_tokens, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "DiscussServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +__all__ = ( + "DiscussServiceAsyncClient", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/client.py new file mode 100644 index 000000000000..6301bfd36a15 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/client.py @@ -0,0 +1,712 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+from collections import OrderedDict
+import os
+import re
+from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast
+
+from google.ai.generativelanguage_v1beta2 import gapic_version as package_version
+
+from google.api_core import client_options as client_options_lib
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
+from google.auth import credentials as ga_credentials  # type: ignore
+from google.auth.transport import mtls  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+from google.auth.exceptions import MutualTLSChannelError  # type: ignore
+from google.oauth2 import service_account  # type: ignore
+
+try:
+    OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+except AttributeError:  # pragma: NO COVER
+    OptionalRetry = Union[retries.Retry, object]  # type: ignore
+
+from google.ai.generativelanguage_v1beta2.types import discuss_service
+from google.ai.generativelanguage_v1beta2.types import safety
+from .transports.base import DiscussServiceTransport, DEFAULT_CLIENT_INFO
+from .transports.grpc import DiscussServiceGrpcTransport
+from .transports.grpc_asyncio import DiscussServiceGrpcAsyncIOTransport
+from .transports.rest import DiscussServiceRestTransport
+
+
+class DiscussServiceClientMeta(type):
+    """Metaclass for the DiscussService client.
+
+    This provides class-level methods for building and retrieving
+    support objects (e.g. transport) without polluting the client instance
+    objects.
+    """
+    _transport_registry = OrderedDict()  # type: Dict[str, Type[DiscussServiceTransport]]
+    _transport_registry["grpc"] = DiscussServiceGrpcTransport
+    _transport_registry["grpc_asyncio"] = DiscussServiceGrpcAsyncIOTransport
+    _transport_registry["rest"] = DiscussServiceRestTransport
+
+    def get_transport_class(cls,
+            label: Optional[str] = None,
+        ) -> Type[DiscussServiceTransport]:
+        """Returns an appropriate transport class.
+
+        Args:
+            label: The name of the desired transport. If none is
+                provided, then the first transport in the registry is used.
+
+        Returns:
+            The transport class to use.
+        """
+        # If a specific transport is requested, return that one.
+        if label:
+            return cls._transport_registry[label]
+
+        # No transport is requested; return the default (that is, the first one
+        # in the dictionary).
+        return next(iter(cls._transport_registry.values()))
+
+
+class DiscussServiceClient(metaclass=DiscussServiceClientMeta):
+    """An API for using Generative Language Models (GLMs) in dialog
+    applications.
+    Also known as large language models (LLMs), this API provides
+    models that are trained for multi-turn dialog.
+    """
+
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Converts api endpoint to mTLS endpoint.
+
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+ ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DiscussServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DiscussServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> DiscussServiceTransport: + """Returns the transport used by the client instance. + + Returns: + DiscussServiceTransport: The transport used by the client + instance. 
+ """ + return self._transport + + @staticmethod + def model_path(model: str,) -> str: + """Returns a fully-qualified model string.""" + return "models/{model}".format(model=model, ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str,str]: + """Parses a model path into its component segments.""" + m = re.match(r"^models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. 
+ + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`") + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`") + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + def __init__(self, *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[Union[str, DiscussServiceTransport]] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the discuss service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, DiscussServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. 
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + client_options = cast(client_options_lib.ClientOptions, client_options) + + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) + + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError("client_options.api_key and credentials are mutually exclusive") + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, DiscussServiceTransport): + # transport is a DiscussServiceTransport instance. + if credentials or client_options.credentials_file or api_key_value: + raise ValueError("When providing a transport instance, " + "provide its credentials directly.") + if client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = transport + else: + import google.auth._default # type: ignore + + if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): + credentials = google.auth._default.get_api_key_credentials(api_key_value) + + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=client_options.api_audience, + ) + + def generate_message(self, + request: Optional[Union[discuss_service.GenerateMessageRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> discuss_service.GenerateMessageResponse: + r"""Generates a response from the model given an input + ``MessagePrompt``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta2 + + def sample_generate_message(): + # Create a client + client = generativelanguage_v1beta2.DiscussServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta2.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1beta2.GenerateMessageRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.generate_message(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta2.types.GenerateMessageRequest, dict]): + The request object. Request to generate a message + response from the model. + model (str): + Required. The name of the model to use. + + Format: ``name=models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (google.ai.generativelanguage_v1beta2.types.MessagePrompt): + Required. The structured textual + input given to the model as a prompt. + Given a + prompt, the model will return what it + predicts is the next message in the + discussion. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + temperature (float): + Optional. Controls the randomness of the output. + + Values can range over ``[0.0,1.0]``, inclusive. A value + closer to ``1.0`` will produce responses that are more + varied, while a value closer to ``0.0`` will typically + result in less surprising responses from the model. + + This corresponds to the ``temperature`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + candidate_count (int): + Optional. The number of generated response messages to + return. + + This value must be between ``[1, 8]``, inclusive. If + unset, this will default to ``1``. + + This corresponds to the ``candidate_count`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_p (float): + Optional. The maximum cumulative probability of tokens + to consider when sampling. + + The model uses combined Top-k and nucleus sampling. + + Nucleus sampling considers the smallest set of tokens + whose probability sum is at least ``top_p``. + + This corresponds to the ``top_p`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_k (int): + Optional. The maximum number of tokens to consider when + sampling. + + The model uses combined Top-k and nucleus sampling. + + Top-k sampling considers the set of ``top_k`` most + probable tokens. + + This corresponds to the ``top_k`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta2.types.GenerateMessageResponse: + The response from the model. + + This includes candidate messages and + conversation history in the form of + chronologically-ordered messages. + + """ + # Create or coerce a protobuf request object. 
+ # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt, temperature, candidate_count, top_p, top_k]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a discuss_service.GenerateMessageRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, discuss_service.GenerateMessageRequest): + request = discuss_service.GenerateMessageRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + if temperature is not None: + request.temperature = temperature + if candidate_count is not None: + request.candidate_count = candidate_count + if top_p is not None: + request.top_p = top_p + if top_k is not None: + request.top_k = top_k + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.generate_message] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def count_message_tokens(self, + request: Optional[Union[discuss_service.CountMessageTokensRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> discuss_service.CountMessageTokensResponse: + r"""Runs a model's tokenizer on a string and returns the + token count. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta2 + + def sample_count_message_tokens(): + # Create a client + client = generativelanguage_v1beta2.DiscussServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta2.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1beta2.CountMessageTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.count_message_tokens(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta2.types.CountMessageTokensRequest, dict]): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + model (str): + Required. The model's resource name. This serves as an + ID for the Model to use. 
+ + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (google.ai.generativelanguage_v1beta2.types.MessagePrompt): + Required. The prompt, whose token + count is to be returned. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta2.types.CountMessageTokensResponse: + A response from CountMessageTokens. + + It returns the model's token_count for the prompt. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a discuss_service.CountMessageTokensRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, discuss_service.CountMessageTokensRequest): + request = discuss_service.CountMessageTokensRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.count_message_tokens] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "DiscussServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! 
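+
+        A minimal usage sketch (``request`` is assumed to have been built
+        as in the samples above):
+
+        .. code-block:: python
+
+            with DiscussServiceClient() as client:
+                response = client.count_message_tokens(request=request)
+            # The transport is closed at this point; create a new client
+            # for any further calls.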
+ """ + self.transport.close() + + + + + + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +__all__ = ( + "DiscussServiceClient", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/__init__.py new file mode 100644 index 000000000000..b585c1ce424c --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/__init__.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import DiscussServiceTransport +from .grpc import DiscussServiceGrpcTransport +from .grpc_asyncio import DiscussServiceGrpcAsyncIOTransport +from .rest import DiscussServiceRestTransport +from .rest import DiscussServiceRestInterceptor + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[DiscussServiceTransport]] +_transport_registry['grpc'] = DiscussServiceGrpcTransport +_transport_registry['grpc_asyncio'] = DiscussServiceGrpcAsyncIOTransport +_transport_registry['rest'] = DiscussServiceRestTransport + +__all__ = ( + 'DiscussServiceTransport', + 'DiscussServiceGrpcTransport', + 'DiscussServiceGrpcAsyncIOTransport', + 'DiscussServiceRestTransport', + 'DiscussServiceRestInterceptor', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/base.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/base.py new file mode 100644 index 000000000000..c7d8455ba342 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/base.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+import abc
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Union
+
+from google.ai.generativelanguage_v1beta2 import gapic_version as package_version
+
+import google.auth  # type: ignore
+import google.api_core
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
+from google.auth import credentials as ga_credentials  # type: ignore
+from google.oauth2 import service_account  # type: ignore
+
+from google.ai.generativelanguage_v1beta2.types import discuss_service
+
+DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__)
+
+
+class DiscussServiceTransport(abc.ABC):
+    """Abstract transport class for DiscussService."""
+
+    AUTH_SCOPES = (
+    )
+
+    DEFAULT_HOST: str = 'generativelanguage.googleapis.com'
+    def __init__(
+            self, *,
+            host: str = DEFAULT_HOST,
+            credentials: Optional[ga_credentials.Credentials] = None,
+            credentials_file: Optional[str] = None,
+            scopes: Optional[Sequence[str]] = None,
+            quota_project_id: Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            always_use_jwt_access: Optional[bool] = False,
+            api_audience: Optional[str] = None,
+            **kwargs,
+            ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]):
+                 The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
+        """
+
+        scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES}
+
+        # Save the scopes.
+        self._scopes = scopes
+
+        # If no credentials are provided, then determine the appropriate
+        # defaults.
+        if credentials and credentials_file:
+            raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive")
+
+        if credentials_file is not None:
+            credentials, _ = google.auth.load_credentials_from_file(
+                credentials_file,
+                **scopes_kwargs,
+                quota_project_id=quota_project_id
+            )
+        elif credentials is None:
+            credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id)
+            # Don't apply the GDCH audience if the credentials were passed in by the user.
+            if hasattr(credentials, "with_gdch_audience"):
+                credentials = credentials.with_gdch_audience(api_audience if api_audience else host)
+
+        # If the credentials are service account credentials, then always try to use self signed JWT.
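+        # A self-signed JWT lets service account credentials sign a token
+        # locally with the account's private key instead of exchanging it
+        # for an OAuth2 access token over the network. The hasattr() check
+        # below guards against older google-auth releases that predate
+        # with_always_use_jwt_access.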
+ if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ':' not in host: + host += ':443' + self._host = host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.generate_message: gapic_v1.method.wrap_method( + self.generate_message, + default_timeout=None, + client_info=client_info, + ), + self.count_message_tokens: gapic_v1.method.wrap_method( + self.count_message_tokens, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def generate_message(self) -> Callable[ + [discuss_service.GenerateMessageRequest], + Union[ + discuss_service.GenerateMessageResponse, + Awaitable[discuss_service.GenerateMessageResponse] + ]]: + raise NotImplementedError() + + @property + def count_message_tokens(self) -> Callable[ + [discuss_service.CountMessageTokensRequest], + Union[ + discuss_service.CountMessageTokensResponse, + Awaitable[discuss_service.CountMessageTokensResponse] + ]]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ( + 'DiscussServiceTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc.py new file mode 100644 index 000000000000..7fc2d1e9779c --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import grpc_helpers +from google.api_core import gapic_v1 +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.ai.generativelanguage_v1beta2.types import discuss_service +from .base import DiscussServiceTransport, DEFAULT_CLIENT_INFO + + +class DiscussServiceGrpcTransport(DiscussServiceTransport): + """gRPC backend transport for DiscussService. + + An API for using Generative Language Models (GLMs) in dialog + applications. 
+    Also known as large language models (LLMs), these models are
+    trained for multi-turn dialog.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+    _stubs: Dict[str, Callable]
+
+    def __init__(self, *,
+            host: str = 'generativelanguage.googleapis.com',
+            credentials: Optional[ga_credentials.Credentials] = None,
+            credentials_file: Optional[str] = None,
+            scopes: Optional[Sequence[str]] = None,
+            channel: Optional[grpc.Channel] = None,
+            api_mtls_endpoint: Optional[str] = None,
+            client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+            ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None,
+            client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+            quota_project_id: Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            always_use_jwt_access: Optional[bool] = False,
+            api_audience: Optional[str] = None,
+            ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]):
+                 The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
+            channel (Optional[grpc.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for the grpc channel. It is ignored if ``channel`` is provided.
+            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                A callback to provide client certificate bytes and private key bytes,
+                both in PEM format. It is used to configure a mutual TLS channel. It is
+                ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        self._grpc_channel = None
+        self._ssl_channel_credentials = ssl_channel_credentials
+        self._stubs: Dict[str, Callable] = {}
+
+        if api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+        if client_cert_source:
+            warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+        if channel:
+            # Ignore credentials if a channel was passed.
+            credentials = False
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+            self._ssl_channel_credentials = None
+
+        else:
+            if api_mtls_endpoint:
+                host = api_mtls_endpoint
+
+                # Create SSL credentials with client_cert_source or application
+                # default SSL credentials.
+                if client_cert_source:
+                    cert, key = client_cert_source()
+                    self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+                        certificate_chain=cert, private_key=key
+                    )
+                else:
+                    self._ssl_channel_credentials = SslCredentials().ssl_credentials
+
+            else:
+                if client_cert_source_for_mtls and not ssl_channel_credentials:
+                    cert, key = client_cert_source_for_mtls()
+                    self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+                        certificate_chain=cert, private_key=key
+                    )
+
+        # The base transport sets the host, credentials and scopes
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+            always_use_jwt_access=always_use_jwt_access,
+            api_audience=api_audience,
+        )
+
+        if not self._grpc_channel:
+            self._grpc_channel = type(self).create_channel(
+                self._host,
+                # use the credentials which are saved
+                credentials=self._credentials,
+                # Set ``credentials_file`` to ``None`` here as
+                # the credentials that we saved earlier should be used.
+                credentials_file=None,
+                scopes=self._scopes,
+                ssl_credentials=self._ssl_channel_credentials,
+                quota_project_id=quota_project_id,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+
+        # Wrap messages. This must be done after self._grpc_channel exists
+        self._prep_wrapped_messages(client_info)
+
+    @classmethod
+    def create_channel(cls,
+                       host: str = 'generativelanguage.googleapis.com',
+                       credentials: Optional[ga_credentials.Credentials] = None,
+                       credentials_file: Optional[str] = None,
+                       scopes: Optional[Sequence[str]] = None,
+                       quota_project_id: Optional[str] = None,
+                       **kwargs) -> grpc.Channel:
+        """Create and return a gRPC channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            grpc.Channel: A gRPC channel object.
+ + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def generate_message(self) -> Callable[ + [discuss_service.GenerateMessageRequest], + discuss_service.GenerateMessageResponse]: + r"""Return a callable for the generate message method over gRPC. + + Generates a response from the model given an input + ``MessagePrompt``. + + Returns: + Callable[[~.GenerateMessageRequest], + ~.GenerateMessageResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'generate_message' not in self._stubs: + self._stubs['generate_message'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta2.DiscussService/GenerateMessage', + request_serializer=discuss_service.GenerateMessageRequest.serialize, + response_deserializer=discuss_service.GenerateMessageResponse.deserialize, + ) + return self._stubs['generate_message'] + + @property + def count_message_tokens(self) -> Callable[ + [discuss_service.CountMessageTokensRequest], + discuss_service.CountMessageTokensResponse]: + r"""Return a callable for the count message tokens method over gRPC. + + Runs a model's tokenizer on a string and returns the + token count. + + Returns: + Callable[[~.CountMessageTokensRequest], + ~.CountMessageTokensResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'count_message_tokens' not in self._stubs: + self._stubs['count_message_tokens'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta2.DiscussService/CountMessageTokens', + request_serializer=discuss_service.CountMessageTokensRequest.serialize, + response_deserializer=discuss_service.CountMessageTokensResponse.deserialize, + ) + return self._stubs['count_message_tokens'] + + def close(self): + self.grpc_channel.close() + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ( + 'DiscussServiceGrpcTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc_asyncio.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..97e6d426fc5c --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc_asyncio.py @@ -0,0 +1,294 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union
+
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers_async
+from google.auth import credentials as ga_credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.ai.generativelanguage_v1beta2.types import discuss_service
+from .base import DiscussServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import DiscussServiceGrpcTransport
+
+
+class DiscussServiceGrpcAsyncIOTransport(DiscussServiceTransport):
+    """gRPC AsyncIO backend transport for DiscussService.
+
+    An API for using Generative Language Models (GLMs) in dialog
+    applications.
+    Also known as large language models (LLMs), these models are
+    trained for multi-turn dialog.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(cls,
+                       host: str = 'generativelanguage.googleapis.com',
+                       credentials: Optional[ga_credentials.Credentials] = None,
+                       credentials_file: Optional[str] = None,
+                       scopes: Optional[Sequence[str]] = None,
+                       quota_project_id: Optional[str] = None,
+                       **kwargs) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
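+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.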
+ """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs + ) + + def __init__(self, *, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[aio.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
+ """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def generate_message(self) -> Callable[ + [discuss_service.GenerateMessageRequest], + Awaitable[discuss_service.GenerateMessageResponse]]: + r"""Return a callable for the generate message method over gRPC. + + Generates a response from the model given an input + ``MessagePrompt``. + + Returns: + Callable[[~.GenerateMessageRequest], + Awaitable[~.GenerateMessageResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if 'generate_message' not in self._stubs: + self._stubs['generate_message'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta2.DiscussService/GenerateMessage', + request_serializer=discuss_service.GenerateMessageRequest.serialize, + response_deserializer=discuss_service.GenerateMessageResponse.deserialize, + ) + return self._stubs['generate_message'] + + @property + def count_message_tokens(self) -> Callable[ + [discuss_service.CountMessageTokensRequest], + Awaitable[discuss_service.CountMessageTokensResponse]]: + r"""Return a callable for the count message tokens method over gRPC. + + Runs a model's tokenizer on a string and returns the + token count. + + Returns: + Callable[[~.CountMessageTokensRequest], + Awaitable[~.CountMessageTokensResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'count_message_tokens' not in self._stubs: + self._stubs['count_message_tokens'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta2.DiscussService/CountMessageTokens', + request_serializer=discuss_service.CountMessageTokensRequest.serialize, + response_deserializer=discuss_service.CountMessageTokensResponse.deserialize, + ) + return self._stubs['count_message_tokens'] + + def close(self): + return self.grpc_channel.close() + + +__all__ = ( + 'DiscussServiceGrpcAsyncIOTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/rest.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/rest.py new file mode 100644 index 000000000000..fd68266db64d --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/rest.py @@ -0,0 +1,433 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from google.auth.transport.requests import AuthorizedSession # type: ignore +import json # type: ignore +import grpc # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core import rest_helpers +from google.api_core import rest_streaming +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from requests import __version__ as requests_version +import dataclasses +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + + +from google.ai.generativelanguage_v1beta2.types import discuss_service + +from .base import DiscussServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class DiscussServiceRestInterceptor: + """Interceptor for DiscussService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the DiscussServiceRestTransport. + + .. code-block:: python + class MyCustomDiscussServiceInterceptor(DiscussServiceRestInterceptor): + def pre_count_message_tokens(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_count_message_tokens(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_generate_message(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_generate_message(self, response): + logging.log(f"Received response: {response}") + return response + + transport = DiscussServiceRestTransport(interceptor=MyCustomDiscussServiceInterceptor()) + client = DiscussServiceClient(transport=transport) + + + """ + def pre_count_message_tokens(self, request: discuss_service.CountMessageTokensRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[discuss_service.CountMessageTokensRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for count_message_tokens + + Override in a subclass to manipulate the request or metadata + before they are sent to the DiscussService server. + """ + return request, metadata + + def post_count_message_tokens(self, response: discuss_service.CountMessageTokensResponse) -> discuss_service.CountMessageTokensResponse: + """Post-rpc interceptor for count_message_tokens + + Override in a subclass to manipulate the response + after it is returned by the DiscussService server but before + it is returned to user code. 
+ """ + return response + def pre_generate_message(self, request: discuss_service.GenerateMessageRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[discuss_service.GenerateMessageRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for generate_message + + Override in a subclass to manipulate the request or metadata + before they are sent to the DiscussService server. + """ + return request, metadata + + def post_generate_message(self, response: discuss_service.GenerateMessageResponse) -> discuss_service.GenerateMessageResponse: + """Post-rpc interceptor for generate_message + + Override in a subclass to manipulate the response + after it is returned by the DiscussService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class DiscussServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: DiscussServiceRestInterceptor + + +class DiscussServiceRestTransport(DiscussServiceTransport): + """REST backend transport for DiscussService. + + An API for using Generative Language Models (GLMs) in dialog + applications. + Also known as large language models (LLMs), this API provides + models that are trained for multi-turn dialog. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + + """ + + def __init__(self, *, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[ + ], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = 'https', + interceptor: Optional[DiscussServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. 
+ """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience + ) + self._session = AuthorizedSession( + self._credentials, default_host=self.DEFAULT_HOST) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or DiscussServiceRestInterceptor() + self._prep_wrapped_messages(client_info) + + class _CountMessageTokens(DiscussServiceRestStub): + def __hash__(self): + return hash("CountMessageTokens") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: discuss_service.CountMessageTokensRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> discuss_service.CountMessageTokensResponse: + r"""Call the count message tokens method over HTTP. + + Args: + request (~.discuss_service.CountMessageTokensRequest): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.discuss_service.CountMessageTokensResponse: + A response from ``CountMessageTokens``. + + It returns the model's ``token_count`` for the + ``prompt``. 
+ + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'post', + 'uri': '/v1beta2/{model=models/*}:countMessageTokens', + 'body': '*', + }, + ] + request, metadata = self._interceptor.pre_count_message_tokens(request, metadata) + pb_request = discuss_service.CountMessageTokensRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + including_default_value_fields=False, + use_integers_for_enums=True + ) + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = discuss_service.CountMessageTokensResponse() + pb_resp = discuss_service.CountMessageTokensResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_count_message_tokens(resp) + return resp + + class _GenerateMessage(DiscussServiceRestStub): + def __hash__(self): + return hash("GenerateMessage") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: discuss_service.GenerateMessageRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> discuss_service.GenerateMessageResponse: + r"""Call the generate message method over HTTP. + + Args: + request (~.discuss_service.GenerateMessageRequest): + The request object. Request to generate a message + response from the model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.discuss_service.GenerateMessageResponse: + The response from the model. + + This includes candidate messages and + conversation history in the form of + chronologically-ordered messages. 
+ + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'post', + 'uri': '/v1beta2/{model=models/*}:generateMessage', + 'body': '*', + }, + ] + request, metadata = self._interceptor.pre_generate_message(request, metadata) + pb_request = discuss_service.GenerateMessageRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + including_default_value_fields=False, + use_integers_for_enums=True + ) + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = discuss_service.GenerateMessageResponse() + pb_resp = discuss_service.GenerateMessageResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_generate_message(resp) + return resp + + @property + def count_message_tokens(self) -> Callable[ + [discuss_service.CountMessageTokensRequest], + discuss_service.CountMessageTokensResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CountMessageTokens(self._session, self._host, self._interceptor) # type: ignore + + @property + def generate_message(self) -> Callable[ + [discuss_service.GenerateMessageRequest], + discuss_service.GenerateMessageResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GenerateMessage(self._session, self._host, self._interceptor) # type: ignore + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__=( + 'DiscussServiceRestTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/__init__.py new file mode 100644 index 000000000000..2c368b92d844 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .client import ModelServiceClient +from .async_client import ModelServiceAsyncClient + +__all__ = ( + 'ModelServiceClient', + 'ModelServiceAsyncClient', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/async_client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/async_client.py new file mode 100644 index 000000000000..4710e8d992c2 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/async_client.py @@ -0,0 +1,431 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import functools +import re +from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union + +from google.ai.generativelanguage_v1beta2 import gapic_version as package_version + +from google.api_core.client_options import ClientOptions +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.ai.generativelanguage_v1beta2.services.model_service import pagers +from google.ai.generativelanguage_v1beta2.types import model +from google.ai.generativelanguage_v1beta2.types import model_service +from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport +from .client import ModelServiceClient + + +class ModelServiceAsyncClient: + """Provides methods for getting metadata information about + Generative Models. 
+ """ + + _client: ModelServiceClient + + DEFAULT_ENDPOINT = ModelServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = ModelServiceClient.DEFAULT_MTLS_ENDPOINT + + model_path = staticmethod(ModelServiceClient.model_path) + parse_model_path = staticmethod(ModelServiceClient.parse_model_path) + common_billing_account_path = staticmethod(ModelServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(ModelServiceClient.parse_common_billing_account_path) + common_folder_path = staticmethod(ModelServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(ModelServiceClient.parse_common_folder_path) + common_organization_path = staticmethod(ModelServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(ModelServiceClient.parse_common_organization_path) + common_project_path = staticmethod(ModelServiceClient.common_project_path) + parse_common_project_path = staticmethod(ModelServiceClient.parse_common_project_path) + common_location_path = staticmethod(ModelServiceClient.common_location_path) + parse_common_location_path = staticmethod(ModelServiceClient.parse_common_location_path) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelServiceAsyncClient: The constructed client. + """ + return ModelServiceClient.from_service_account_info.__func__(ModelServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelServiceAsyncClient: The constructed client. + """ + return ModelServiceClient.from_service_account_file.__func__(ModelServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. 
+                in this method.
+
+        Returns:
+            Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+                client cert source to use.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+        """
+        return ModelServiceClient.get_mtls_endpoint_and_cert_source(client_options)  # type: ignore
+
+    @property
+    def transport(self) -> ModelServiceTransport:
+        """Returns the transport used by the client instance.
+
+        Returns:
+            ModelServiceTransport: The transport used by the client instance.
+        """
+        return self._client.transport
+
+    get_transport_class = functools.partial(type(ModelServiceClient).get_transport_class, type(ModelServiceClient))
+
+    def __init__(self, *,
+            credentials: Optional[ga_credentials.Credentials] = None,
+            transport: Union[str, ModelServiceTransport] = "grpc_asyncio",
+            client_options: Optional[ClientOptions] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            ) -> None:
+        """Instantiates the model service client.
+
+        Args:
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            transport (Union[str, ~.ModelServiceTransport]): The
+                transport to use. If set to None, a transport is chosen
+                automatically.
+            client_options (ClientOptions): Custom options for the client. It
+                won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. The GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto-switch to the
+                default mTLS endpoint if a client certificate is present; this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide a client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+        """
+        self._client = ModelServiceClient(
+            credentials=credentials,
+            transport=transport,
+            client_options=client_options,
+            client_info=client_info,
+        )
+
+    async def get_model(self,
+            request: Optional[Union[model_service.GetModelRequest, dict]] = None,
+            *,
+            name: Optional[str] = None,
+            retry: OptionalRetry = gapic_v1.method.DEFAULT,
+            timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> model.Model:
+        r"""Gets information about a specific Model.
+
+        .. code-block:: python
+
+            # This snippet has been automatically generated and should be regarded as a
+            # code template only.
+            # It will require modifications to work:
+            # - It may require correct/in-range values for request initialization.
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta2 + + async def sample_get_model(): + # Create a client + client = generativelanguage_v1beta2.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta2.GetModelRequest( + name="name_value", + ) + + # Make the request + response = await client.get_model(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta2.types.GetModelRequest, dict]]): + The request object. Request for getting information about + a specific Model. + name (:class:`str`): + Required. The resource name of the model. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta2.types.Model: + Information about a Generative + Language Model. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = model_service.GetModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_models(self, + request: Optional[Union[model_service.ListModelsRequest, dict]] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsAsyncPager: + r"""Lists models available through the API. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+            # - It may require specifying regional endpoints when creating the service
+            #   client as shown in:
+            #   https://googleapis.dev/python/google-api-core/latest/client_options.html
+            from google.ai import generativelanguage_v1beta2
+
+            async def sample_list_models():
+                # Create a client
+                client = generativelanguage_v1beta2.ModelServiceAsyncClient()
+
+                # Initialize request argument(s)
+                request = generativelanguage_v1beta2.ListModelsRequest(
+                )
+
+                # Make the request
+                page_result = await client.list_models(request=request)
+
+                # Handle the response
+                async for response in page_result:
+                    print(response)
+
+        Args:
+            request (Optional[Union[google.ai.generativelanguage_v1beta2.types.ListModelsRequest, dict]]):
+                The request object. Request for listing all Models.
+            page_size (:class:`int`):
+                The maximum number of ``Models`` to return (per page).
+
+                The service may return fewer models. If unspecified, at
+                most 50 models will be returned per page. This method
+                returns at most 1000 models per page, even if you pass a
+                larger page_size.
+
+                This corresponds to the ``page_size`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            page_token (:class:`str`):
+                A page token, received from a previous ``ListModels``
+                call.
+
+                Provide the ``page_token`` returned by one request as an
+                argument to the next request to retrieve the next page.
+
+                When paginating, all other parameters provided to
+                ``ListModels`` must match the call that provided the
+                page token.
+
+                This corresponds to the ``page_token`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.ai.generativelanguage_v1beta2.services.model_service.pagers.ListModelsAsyncPager:
+                Response from ListModel containing a paginated list of
+                Models.
+
+                Iterating over this object will yield results and
+                resolve additional pages automatically.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Quick check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([page_size, page_token])
+        if request is not None and has_flattened_params:
+            raise ValueError("If the `request` argument is set, then none of "
+                             "the individual field arguments should be set.")
+
+        request = model_service.ListModelsRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+        if page_size is not None:
+            request.page_size = page_size
+        if page_token is not None:
+            request.page_token = page_token
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.list_models,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__aiter__` convenience method.
+        response = pagers.ListModelsAsyncPager(
+            method=rpc,
+            request=request,
+            response=response,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+ return response + + async def __aenter__(self) -> "ModelServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +__all__ = ( + "ModelServiceAsyncClient", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/client.py new file mode 100644 index 000000000000..9bcf43c759e5 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/client.py @@ -0,0 +1,635 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import os +import re +from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast + +from google.ai.generativelanguage_v1beta2 import gapic_version as package_version + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.ai.generativelanguage_v1beta2.services.model_service import pagers +from google.ai.generativelanguage_v1beta2.types import model +from google.ai.generativelanguage_v1beta2.types import model_service +from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import ModelServiceGrpcTransport +from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport +from .transports.rest import ModelServiceRestTransport + + +class ModelServiceClientMeta(type): + """Metaclass for the ModelService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] + _transport_registry["grpc"] = ModelServiceGrpcTransport + _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport + _transport_registry["rest"] = ModelServiceRestTransport + + def get_transport_class(cls, + label: Optional[str] = None, + ) -> Type[ModelServiceTransport]: + """Returns an appropriate transport class. 
+
+        Args:
+            label: The name of the desired transport. If none is
+                provided, then the first transport in the registry is used.
+
+        Returns:
+            The transport class to use.
+        """
+        # If a specific transport is requested, return that one.
+        if label:
+            return cls._transport_registry[label]
+
+        # No transport is requested; return the default (that is, the first one
+        # in the dictionary).
+        return next(iter(cls._transport_registry.values()))
+
+
+class ModelServiceClient(metaclass=ModelServiceClientMeta):
+    """Provides methods for getting metadata information about
+    Generative Models.
+    """
+
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Converts api endpoint to mTLS endpoint.
+
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = "generativelanguage.googleapis.com"
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
+    )
+
+    @classmethod
+    def from_service_account_info(cls, info: dict, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials
+            info.
+
+        Args:
+            info (dict): The service account private key info.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            ModelServiceClient: The constructed client.
+        """
+        credentials = service_account.Credentials.from_service_account_info(info)
+        kwargs["credentials"] = credentials
+        return cls(*args, **kwargs)
+
+    @classmethod
+    def from_service_account_file(cls, filename: str, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials
+            file.
+
+        Args:
+            filename (str): The path to the service account private key json
+                file.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            ModelServiceClient: The constructed client.
+        """
+        credentials = service_account.Credentials.from_service_account_file(
+            filename)
+        kwargs["credentials"] = credentials
+        return cls(*args, **kwargs)
+
+    from_service_account_json = from_service_account_file
+
+    @property
+    def transport(self) -> ModelServiceTransport:
+        """Returns the transport used by the client instance.
+
+        Returns:
+            ModelServiceTransport: The transport used by the client
+                instance.
+        """
+        return self._transport
+
+    @staticmethod
+    def model_path(model: str,) -> str:
+        """Returns a fully-qualified model string."""
+        return "models/{model}".format(model=model, )
+
+    @staticmethod
+    def parse_model_path(path: str) -> Dict[str,str]:
+        """Parses a model path into its component segments."""
+        m = re.match(r"^models/(?P<model>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_billing_account_path(billing_account: str, ) -> str:
+        """Returns a fully-qualified billing_account string."""
+        return "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+
+    @staticmethod
+    def parse_common_billing_account_path(path: str) -> Dict[str,str]:
+        """Parse a billing_account path into its component segments."""
+        m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_folder_path(folder: str, ) -> str:
+        """Returns a fully-qualified folder string."""
+        return "folders/{folder}".format(folder=folder, )
+
+    @staticmethod
+    def parse_common_folder_path(path: str) -> Dict[str,str]:
+        """Parse a folder path into its component segments."""
+        m = re.match(r"^folders/(?P<folder>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_organization_path(organization: str, ) -> str:
+        """Returns a fully-qualified organization string."""
+        return "organizations/{organization}".format(organization=organization, )
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str,str]:
+        """Parse an organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(project: str, ) -> str:
+        """Returns a fully-qualified project string."""
+        return "projects/{project}".format(project=project, )
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str,str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(project: str, location: str, ) -> str:
+        """Returns a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(project=project, location=location, )
+
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str,str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @classmethod
+    def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None):
+        """Return the API endpoint and client cert source for mutual TLS.
+
+        The client cert source is determined in the following order:
+        (1) if the `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+        client cert source is None.
+        (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+        default client cert source exists, use the default one; otherwise the client cert
+        source is None.
+
+        The API endpoint is determined in the following order:
+        (1) if `client_options.api_endpoint` is provided, use the provided one.
+        (2) if the `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+        default mTLS endpoint; if the environment variable is "never", use the default API
+        endpoint; otherwise, if a client cert source exists, use the default mTLS endpoint,
+        else use the default API endpoint.
+ + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`") + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`") + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + def __init__(self, *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[Union[str, ModelServiceTransport]] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the model service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ModelServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. 
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + client_options = cast(client_options_lib.ClientOptions, client_options) + + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) + + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError("client_options.api_key and credentials are mutually exclusive") + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, ModelServiceTransport): + # transport is a ModelServiceTransport instance. + if credentials or client_options.credentials_file or api_key_value: + raise ValueError("When providing a transport instance, " + "provide its credentials directly.") + if client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = transport + else: + import google.auth._default # type: ignore + + if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): + credentials = google.auth._default.get_api_key_credentials(api_key_value) + + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=client_options.api_audience, + ) + + def get_model(self, + request: Optional[Union[model_service.GetModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: + r"""Gets information about a specific Model. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta2 + + def sample_get_model(): + # Create a client + client = generativelanguage_v1beta2.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta2.GetModelRequest( + name="name_value", + ) + + # Make the request + response = client.get_model(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta2.types.GetModelRequest, dict]): + The request object. Request for getting information about + a specific Model. 
+ name (str): + Required. The resource name of the model. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta2.types.Model: + Information about a Generative + Language Model. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.GetModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.GetModelRequest): + request = model_service.GetModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_models(self, + request: Optional[Union[model_service.ListModelsRequest, dict]] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsPager: + r"""Lists models available through the API. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta2 + + def sample_list_models(): + # Create a client + client = generativelanguage_v1beta2.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta2.ListModelsRequest( + ) + + # Make the request + page_result = client.list_models(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta2.types.ListModelsRequest, dict]): + The request object. Request for listing all Models. + page_size (int): + The maximum number of ``Models`` to return (per page). + + The service may return fewer models. 
If unspecified, at + most 50 models will be returned per page. This method + returns at most 1000 models per page, even if you pass a + larger page_size. + + This corresponds to the ``page_size`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + page_token (str): + A page token, received from a previous ``ListModels`` + call. + + Provide the ``page_token`` returned by one request as an + argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListModels`` must match the call that provided the + page token. + + This corresponds to the ``page_token`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta2.services.model_service.pagers.ListModelsPager: + Response from ListModel containing a paginated list of + Models. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([page_size, page_token]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ListModelsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ListModelsRequest): + request = model_service.ListModelsRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if page_size is not None: + request.page_size = page_size + if page_token is not None: + request.page_token = page_token + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_models] + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListModelsPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "ModelServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! 
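+
+        A short sketch of safe usage, assuming this client owns its own
+        transport (the model name below is a placeholder):
+
+        .. code-block:: python
+
+            with ModelServiceClient() as client:
+                client.get_model(request={"name": "models/your-model-id"})
+            # The transport is closed here; do not reuse the client.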
+        """
+        self.transport.close()
+
+
+DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__)
+
+
+__all__ = (
+    "ModelServiceClient",
+)
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/pagers.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/pagers.py
new file mode 100644
index 000000000000..2183050a4126
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/pagers.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Any, AsyncIterator, Awaitable, Callable, Sequence, Tuple, Optional, Iterator

+
+from google.ai.generativelanguage_v1beta2.types import model
+from google.ai.generativelanguage_v1beta2.types import model_service
+
+
+class ListModelsPager:
+    """A pager for iterating through ``list_models`` requests.
+
+    This class thinly wraps an initial
+    :class:`google.ai.generativelanguage_v1beta2.types.ListModelsResponse` object, and
+    provides an ``__iter__`` method to iterate through its
+    ``models`` field.
+
+    If there are more pages, the ``__iter__`` method will make additional
+    ``ListModels`` requests and continue to iterate
+    through the ``models`` field on the
+    corresponding responses.
+
+    All the usual :class:`google.ai.generativelanguage_v1beta2.types.ListModelsResponse`
+    attributes are available on the pager. If multiple requests are made, only
+    the most recent response is retained, and thus used for attribute lookup.
+    """
+    def __init__(self,
+            method: Callable[..., model_service.ListModelsResponse],
+            request: model_service.ListModelsRequest,
+            response: model_service.ListModelsResponse,
+            *,
+            metadata: Sequence[Tuple[str, str]] = ()):
+        """Instantiate the pager.
+
+        Args:
+            method (Callable): The method that was originally called, and
+                which instantiated this pager.
+            request (google.ai.generativelanguage_v1beta2.types.ListModelsRequest):
+                The initial request object.
+            response (google.ai.generativelanguage_v1beta2.types.ListModelsResponse):
+                The initial response object.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
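+
+        A short illustrative sketch; in normal use the pager is returned
+        by ``ModelServiceClient.list_models`` rather than built directly:
+
+        .. code-block:: python
+
+            pager = client.list_models(request={})
+            for m in pager:  # additional pages are fetched lazily
+                print(m.name)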
+ """ + self._method = method + self._request = model_service.ListModelsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[model_service.ListModelsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterator[model.Model]: + for page in self.pages: + yield from page.models + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListModelsAsyncPager: + """A pager for iterating through ``list_models`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1beta2.types.ListModelsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``models`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListModels`` requests and continue to iterate + through the ``models`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1beta2.types.ListModelsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[model_service.ListModelsResponse]], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1beta2.types.ListModelsRequest): + The initial request object. + response (google.ai.generativelanguage_v1beta2.types.ListModelsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
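+
+        A short illustrative sketch; in normal use the pager is returned
+        by ``ModelServiceAsyncClient.list_models`` rather than built directly:
+
+        .. code-block:: python
+
+            pager = await client.list_models(request={})
+            async for m in pager:  # additional pages are fetched lazily
+                print(m.name)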
+        """
+        self._method = method
+        self._request = model_service.ListModelsRequest(request)
+        self._response = response
+        self._metadata = metadata
+
+    def __getattr__(self, name: str) -> Any:
+        return getattr(self._response, name)
+
+    @property
+    async def pages(self) -> AsyncIterator[model_service.ListModelsResponse]:
+        yield self._response
+        while self._response.next_page_token:
+            self._request.page_token = self._response.next_page_token
+            self._response = await self._method(self._request, metadata=self._metadata)
+            yield self._response
+
+    def __aiter__(self) -> AsyncIterator[model.Model]:
+        async def async_generator():
+            async for page in self.pages:
+                for response in page.models:
+                    yield response
+
+        return async_generator()
+
+    def __repr__(self) -> str:
+        return '{0}<{1!r}>'.format(self.__class__.__name__, self._response)
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/__init__.py
new file mode 100644
index 000000000000..c51cadf4ba09
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/__init__.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from collections import OrderedDict
+from typing import Dict, Type
+
+from .base import ModelServiceTransport
+from .grpc import ModelServiceGrpcTransport
+from .grpc_asyncio import ModelServiceGrpcAsyncIOTransport
+from .rest import ModelServiceRestTransport
+from .rest import ModelServiceRestInterceptor
+
+
+# Compile a registry of transports.
+_transport_registry = OrderedDict()  # type: Dict[str, Type[ModelServiceTransport]]
+_transport_registry['grpc'] = ModelServiceGrpcTransport
+_transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport
+_transport_registry['rest'] = ModelServiceRestTransport
+
+__all__ = (
+    'ModelServiceTransport',
+    'ModelServiceGrpcTransport',
+    'ModelServiceGrpcAsyncIOTransport',
+    'ModelServiceRestTransport',
+    'ModelServiceRestInterceptor',
+)
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/base.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/base.py
new file mode 100644
index 000000000000..3f41738067e8
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/base.py
@@ -0,0 +1,162 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +from google.ai.generativelanguage_v1beta2 import gapic_version as package_version + +import google.auth # type: ignore +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1beta2.types import model +from google.ai.generativelanguage_v1beta2.types import model_service + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +class ModelServiceTransport(abc.ABC): + """Abstract transport class for ModelService.""" + + AUTH_SCOPES = ( + ) + + DEFAULT_HOST: str = 'generativelanguage.googleapis.com' + def __init__( + self, *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + + # If no credentials are provided, then determine the appropriate + # defaults. 
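+        # Resolution order, as implemented below: an explicit ``credentials``
+        # object wins; otherwise ``credentials_file`` is loaded; otherwise
+        # Application Default Credentials are used. Passing both
+        # ``credentials`` and ``credentials_file`` raises DuplicateCredentialArgs.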
+        if credentials and credentials_file:
+            raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive")
+
+        if credentials_file is not None:
+            credentials, _ = google.auth.load_credentials_from_file(
+                credentials_file,
+                **scopes_kwargs,
+                quota_project_id=quota_project_id
+            )
+        elif credentials is None:
+            credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id)
+            # Don't apply the audience if a credentials file was passed in by the user.
+            if hasattr(credentials, "with_gdch_audience"):
+                credentials = credentials.with_gdch_audience(api_audience if api_audience else host)
+
+        # If the credentials are service account credentials, then always try to use self signed JWT.
+        if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"):
+            credentials = credentials.with_always_use_jwt_access(True)
+
+        # Save the credentials.
+        self._credentials = credentials
+
+        # Save the hostname. Default to port 443 (HTTPS) if none is specified.
+        if ':' not in host:
+            host += ':443'
+        self._host = host
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
+        self._wrapped_methods = {
+            self.get_model: gapic_v1.method.wrap_method(
+                self.get_model,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+            self.list_models: gapic_v1.method.wrap_method(
+                self.list_models,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+        }
+
+    def close(self):
+        """Closes resources associated with the transport.
+
+        .. warning::
+             Only call this method if the transport is NOT shared
+             with other clients - this may cause errors in other clients!
+        """
+        raise NotImplementedError()
+
+    @property
+    def get_model(self) -> Callable[
+            [model_service.GetModelRequest],
+            Union[
+                model.Model,
+                Awaitable[model.Model]
+            ]]:
+        raise NotImplementedError()
+
+    @property
+    def list_models(self) -> Callable[
+            [model_service.ListModelsRequest],
+            Union[
+                model_service.ListModelsResponse,
+                Awaitable[model_service.ListModelsResponse]
+            ]]:
+        raise NotImplementedError()
+
+    @property
+    def kind(self) -> str:
+        raise NotImplementedError()
+
+
+__all__ = (
+    'ModelServiceTransport',
+)
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc.py
new file mode 100644
index 000000000000..892193957a68
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc.py
@@ -0,0 +1,292 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import warnings
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+
+from google.api_core import grpc_helpers
+from google.api_core import gapic_v1
+import google.auth                         # type: ignore
+from google.auth import credentials as ga_credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+
+from google.ai.generativelanguage_v1beta2.types import model
+from google.ai.generativelanguage_v1beta2.types import model_service
+from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO
+
+
+class ModelServiceGrpcTransport(ModelServiceTransport):
+    """gRPC backend transport for ModelService.
+
+    Provides methods for getting metadata information about
+    Generative Models.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+    _stubs: Dict[str, Callable]
+
+    def __init__(self, *,
+            host: str = 'generativelanguage.googleapis.com',
+            credentials: Optional[ga_credentials.Credentials] = None,
+            credentials_file: Optional[str] = None,
+            scopes: Optional[Sequence[str]] = None,
+            channel: Optional[grpc.Channel] = None,
+            api_mtls_endpoint: Optional[str] = None,
+            client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+            ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None,
+            client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+            quota_project_id: Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            always_use_jwt_access: Optional[bool] = False,
+            api_audience: Optional[str] = None,
+            ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]):
+                 The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
+            channel (Optional[grpc.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for the grpc channel. It is ignored if ``channel`` is provided.
+            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                A callback to provide client certificate bytes and private key bytes,
+                both in PEM format. It is used to configure a mutual TLS channel.
It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel(cls, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. 
If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+
+        return grpc_helpers.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            quota_project_id=quota_project_id,
+            default_scopes=cls.AUTH_SCOPES,
+            scopes=scopes,
+            default_host=cls.DEFAULT_HOST,
+            **kwargs
+        )
+
+    @property
+    def grpc_channel(self) -> grpc.Channel:
+        """Return the channel designed to connect to this service.
+        """
+        return self._grpc_channel
+
+    @property
+    def get_model(self) -> Callable[
+            [model_service.GetModelRequest],
+            model.Model]:
+        r"""Return a callable for the get model method over gRPC.
+
+        Gets information about a specific Model.
+
+        Returns:
+            Callable[[~.GetModelRequest],
+                    ~.Model]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if 'get_model' not in self._stubs:
+            self._stubs['get_model'] = self.grpc_channel.unary_unary(
+                '/google.ai.generativelanguage.v1beta2.ModelService/GetModel',
+                request_serializer=model_service.GetModelRequest.serialize,
+                response_deserializer=model.Model.deserialize,
+            )
+        return self._stubs['get_model']
+
+    @property
+    def list_models(self) -> Callable[
+            [model_service.ListModelsRequest],
+            model_service.ListModelsResponse]:
+        r"""Return a callable for the list models method over gRPC.
+
+        Lists models available through the API.
+
+        Returns:
+            Callable[[~.ListModelsRequest],
+                    ~.ListModelsResponse]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
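+        # The stub is created once and cached in ``self._stubs``; subsequent
+        # property accesses reuse the same underlying gRPC callable.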
+ if 'list_models' not in self._stubs: + self._stubs['list_models'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta2.ModelService/ListModels', + request_serializer=model_service.ListModelsRequest.serialize, + response_deserializer=model_service.ListModelsResponse.deserialize, + ) + return self._stubs['list_models'] + + def close(self): + self.grpc_channel.close() + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ( + 'ModelServiceGrpcTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc_asyncio.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..49b3a42dee4c --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc_asyncio.py @@ -0,0 +1,291 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers_async +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.ai.generativelanguage_v1beta2.types import model +from google.ai.generativelanguage_v1beta2.types import model_service +from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import ModelServiceGrpcTransport + + +class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport): + """gRPC AsyncIO backend transport for ModelService. + + Provides methods for getting metadata information about + Generative Models. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. 
+        credentials_file (Optional[str]): A file with credentials that can +            be loaded with :func:`google.auth.load_credentials_from_file`. +            This argument is ignored if ``channel`` is provided. +        scopes (Optional[Sequence[str]]): An optional list of scopes needed for this +            service. These are only used when credentials are not specified and +            are passed to :func:`google.auth.default`. +        quota_project_id (Optional[str]): An optional project to use for billing +            and quota. +        kwargs (Optional[dict]): Keyword arguments, which are passed to the +            channel creation. +    Returns: +        aio.Channel: A gRPC AsyncIO channel object. +    """ + +        return grpc_helpers_async.create_channel( +            host, +            credentials=credentials, +            credentials_file=credentials_file, +            quota_project_id=quota_project_id, +            default_scopes=cls.AUTH_SCOPES, +            scopes=scopes, +            default_host=cls.DEFAULT_HOST, +            **kwargs +        ) + +    def __init__(self, *, +            host: str = 'generativelanguage.googleapis.com', +            credentials: Optional[ga_credentials.Credentials] = None, +            credentials_file: Optional[str] = None, +            scopes: Optional[Sequence[str]] = None, +            channel: Optional[aio.Channel] = None, +            api_mtls_endpoint: Optional[str] = None, +            client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, +            ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, +            client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, +            quota_project_id: Optional[str] = None, +            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, +            always_use_jwt_access: Optional[bool] = False, +            api_audience: Optional[str] = None, +            ) -> None: +        """Instantiate the transport. + +        Args: +            host (Optional[str]): +                 The hostname to connect to. +            credentials (Optional[google.auth.credentials.Credentials]): The +                authorization credentials to attach to requests. These +                credentials identify the application to the service; if none +                are specified, the client will attempt to ascertain the +                credentials from the environment. +                This argument is ignored if ``channel`` is provided. +            credentials_file (Optional[str]): A file with credentials that can +                be loaded with :func:`google.auth.load_credentials_from_file`. +                This argument is ignored if ``channel`` is provided. +            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this +                service. These are only used when credentials are not specified and +                are passed to :func:`google.auth.default`. +            channel (Optional[aio.Channel]): A ``Channel`` instance through +                which to make calls. +            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. +                If provided, it overrides the ``host`` argument and tries to create +                a mutual TLS channel with client SSL credentials from +                ``client_cert_source`` or application default SSL credentials. +            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): +                Deprecated. A callback to provide client SSL certificate bytes and +                private key bytes, both in PEM format. It is ignored if +                ``api_mtls_endpoint`` is None. +            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials +                for the grpc channel. It is ignored if ``channel`` is provided. +            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): +                A callback to provide client certificate bytes and private key bytes, +                both in PEM format. It is used to configure a mutual TLS channel. It is +                ignored if ``channel`` or ``ssl_channel_credentials`` is provided. +            quota_project_id (Optional[str]): An optional project to use for billing +                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo): +                The client info used to send a user-agent string along with +                API requests. If ``None``, then default info will be used. +                Generally, you only need to set this if you're developing +                your own client library. +            always_use_jwt_access (Optional[bool]): Whether self signed JWT should +                be used for service account credentials. + +        Raises: +            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport +              creation failed for any reason. +          google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` +              and ``credentials_file`` are passed. +        """ +        self._grpc_channel = None +        self._ssl_channel_credentials = ssl_channel_credentials +        self._stubs: Dict[str, Callable] = {} + +        if api_mtls_endpoint: +            warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) +        if client_cert_source: +            warnings.warn("client_cert_source is deprecated", DeprecationWarning) + +        if channel: +            # Ignore credentials if a channel was passed. +            credentials = False +            # If a channel was explicitly provided, set it. +            self._grpc_channel = channel +            self._ssl_channel_credentials = None +        else: +            if api_mtls_endpoint: +                host = api_mtls_endpoint + +                # Create SSL credentials with client_cert_source or application +                # default SSL credentials. +                if client_cert_source: +                    cert, key = client_cert_source() +                    self._ssl_channel_credentials = grpc.ssl_channel_credentials( +                        certificate_chain=cert, private_key=key +                    ) +                else: +                    self._ssl_channel_credentials = SslCredentials().ssl_credentials + +            else: +                if client_cert_source_for_mtls and not ssl_channel_credentials: +                    cert, key = client_cert_source_for_mtls() +                    self._ssl_channel_credentials = grpc.ssl_channel_credentials( +                        certificate_chain=cert, private_key=key +                    ) + +        # The base transport sets the host, credentials and scopes +        super().__init__( +            host=host, +            credentials=credentials, +            credentials_file=credentials_file, +            scopes=scopes, +            quota_project_id=quota_project_id, +            client_info=client_info, +            always_use_jwt_access=always_use_jwt_access, +            api_audience=api_audience, +        ) + +        if not self._grpc_channel: +            self._grpc_channel = type(self).create_channel( +                self._host, +                # use the credentials which are saved +                credentials=self._credentials, +                # Set ``credentials_file`` to ``None`` here as +                # the credentials that we saved earlier should be used. +                credentials_file=None, +                scopes=self._scopes, +                ssl_credentials=self._ssl_channel_credentials, +                quota_project_id=quota_project_id, +                options=[ +                    ("grpc.max_send_message_length", -1), +                    ("grpc.max_receive_message_length", -1), +                ], +            ) + +        # Wrap messages. This must be done after self._grpc_channel exists +        self._prep_wrapped_messages(client_info) + +    @property +    def grpc_channel(self) -> aio.Channel: +        """Create the channel designed to connect to this service. + +        This property caches on the instance; repeated calls return +        the same channel. +        """ +        # Return the channel from cache. +        return self._grpc_channel + +    @property +    def get_model(self) -> Callable[ +            [model_service.GetModelRequest], +            Awaitable[model.Model]]: +        r"""Return a callable for the get model method over gRPC. + +        Gets information about a specific Model. + +        Returns: +            Callable[[~.GetModelRequest], +                    Awaitable[~.Model]]: +                A function that, when called, will call the underlying RPC +                on the server. +        """ +        # Generate a "stub function" on-the-fly which will actually make +        # the request.
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_model' not in self._stubs: + self._stubs['get_model'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta2.ModelService/GetModel', + request_serializer=model_service.GetModelRequest.serialize, + response_deserializer=model.Model.deserialize, + ) + return self._stubs['get_model'] + + @property + def list_models(self) -> Callable[ + [model_service.ListModelsRequest], + Awaitable[model_service.ListModelsResponse]]: + r"""Return a callable for the list models method over gRPC. + + Lists models available through the API. + + Returns: + Callable[[~.ListModelsRequest], + Awaitable[~.ListModelsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_models' not in self._stubs: + self._stubs['list_models'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta2.ModelService/ListModels', + request_serializer=model_service.ListModelsRequest.serialize, + response_deserializer=model_service.ListModelsResponse.deserialize, + ) + return self._stubs['list_models'] + + def close(self): + return self.grpc_channel.close() + + +__all__ = ( + 'ModelServiceGrpcAsyncIOTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/rest.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/rest.py new file mode 100644 index 000000000000..db28ab8b2b81 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/rest.py @@ -0,0 +1,397 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from google.auth.transport.requests import AuthorizedSession # type: ignore +import json # type: ignore +import grpc # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core import rest_helpers +from google.api_core import rest_streaming +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from requests import __version__ as requests_version +import dataclasses +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + + +from google.ai.generativelanguage_v1beta2.types import model +from google.ai.generativelanguage_v1beta2.types import model_service + +from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class ModelServiceRestInterceptor: + """Interceptor for ModelService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the ModelServiceRestTransport. + + .. code-block:: python + class MyCustomModelServiceInterceptor(ModelServiceRestInterceptor): + def pre_get_model(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_model(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_models(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_models(self, response): + logging.log(f"Received response: {response}") + return response + + transport = ModelServiceRestTransport(interceptor=MyCustomModelServiceInterceptor()) + client = ModelServiceClient(transport=transport) + + + """ + def pre_get_model(self, request: model_service.GetModelRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.GetModelRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for get_model + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_get_model(self, response: model.Model) -> model.Model: + """Post-rpc interceptor for get_model + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + def pre_list_models(self, request: model_service.ListModelsRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.ListModelsRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for list_models + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. 
+ """ + return request, metadata + + def post_list_models(self, response: model_service.ListModelsResponse) -> model_service.ListModelsResponse: + """Post-rpc interceptor for list_models + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class ModelServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: ModelServiceRestInterceptor + + +class ModelServiceRestTransport(ModelServiceTransport): + """REST backend transport for ModelService. + + Provides methods for getting metadata information about + Generative Models. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + + """ + + def __init__(self, *, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[ + ], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = 'https', + interceptor: Optional[ModelServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. 
+        # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the +        # credentials object +        maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host) +        if maybe_url_match is None: +            raise ValueError(f"Unexpected hostname structure: {host}")  # pragma: NO COVER + +        url_match_items = maybe_url_match.groupdict() + +        host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + +        super().__init__( +            host=host, +            credentials=credentials, +            client_info=client_info, +            always_use_jwt_access=always_use_jwt_access, +            api_audience=api_audience +        ) +        self._session = AuthorizedSession( +            self._credentials, default_host=self.DEFAULT_HOST) +        if client_cert_source_for_mtls: +            self._session.configure_mtls_channel(client_cert_source_for_mtls) +        self._interceptor = interceptor or ModelServiceRestInterceptor() +        self._prep_wrapped_messages(client_info) + +    class _GetModel(ModelServiceRestStub): +        def __hash__(self): +            return hash("GetModel") + +        __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] =  { +        } + +        @classmethod +        def _get_unset_required_fields(cls, message_dict): +            return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + +        def __call__(self, +                request: model_service.GetModelRequest, *, +                retry: OptionalRetry=gapic_v1.method.DEFAULT, +                timeout: Optional[float]=None, +                metadata: Sequence[Tuple[str, str]]=(), +                ) -> model.Model: +            r"""Call the get model method over HTTP. + +            Args: +                request (~.model_service.GetModelRequest): +                    The request object. Request for getting information about +                    a specific Model. +                retry (google.api_core.retry.Retry): Designation of what errors, if any, +                    should be retried. +                timeout (float): The timeout for this request. +                metadata (Sequence[Tuple[str, str]]): Strings which should be +                    sent along with the request as metadata. + +            Returns: +                ~.model.Model: +                    Information about a Generative +                Language Model. + +            """ + +            http_options: List[Dict[str, str]] = [{ +                'method': 'get', +                'uri': '/v1beta2/{name=models/*}', +            }, +            ] +            request, metadata = self._interceptor.pre_get_model(request, metadata) +            pb_request = model_service.GetModelRequest.pb(request) +            transcoded_request = path_template.transcode(http_options, pb_request) + +            uri = transcoded_request['uri'] +            method = transcoded_request['method'] + +            # Jsonify the query params +            query_params = json.loads(json_format.MessageToJson( +                transcoded_request['query_params'], +                including_default_value_fields=False, +                use_integers_for_enums=True, +            )) +            query_params.update(self._get_unset_required_fields(query_params)) + +            query_params["$alt"] = "json;enum-encoding=int" + +            # Send the request +            headers = dict(metadata) +            headers['Content-Type'] = 'application/json' +            response = getattr(self._session, method)( +                "{host}{uri}".format(host=self._host, uri=uri), +                timeout=timeout, +                headers=headers, +                params=rest_helpers.flatten_query_params(query_params, strict=True), +                ) + +            # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception +            # subclass.
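+            # (``core_exceptions.from_http_response`` maps the HTTP status code and +            # error payload to the matching ``GoogleAPICallError`` subclass, for +            # example ``NotFound`` for a 404.)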
+            if response.status_code >= 400: +                raise core_exceptions.from_http_response(response) + +            # Return the response +            resp = model.Model() +            pb_resp = model.Model.pb(resp) + +            json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) +            resp = self._interceptor.post_get_model(resp) +            return resp + +    class _ListModels(ModelServiceRestStub): +        def __hash__(self): +            return hash("ListModels") + +        def __call__(self, +                request: model_service.ListModelsRequest, *, +                retry: OptionalRetry=gapic_v1.method.DEFAULT, +                timeout: Optional[float]=None, +                metadata: Sequence[Tuple[str, str]]=(), +                ) -> model_service.ListModelsResponse: +            r"""Call the list models method over HTTP. + +            Args: +                request (~.model_service.ListModelsRequest): +                    The request object. Request for listing all Models. +                retry (google.api_core.retry.Retry): Designation of what errors, if any, +                    should be retried. +                timeout (float): The timeout for this request. +                metadata (Sequence[Tuple[str, str]]): Strings which should be +                    sent along with the request as metadata. + +            Returns: +                ~.model_service.ListModelsResponse: +                    Response from ``ListModels`` containing a paginated list +                of Models. + +            """ + +            http_options: List[Dict[str, str]] = [{ +                'method': 'get', +                'uri': '/v1beta2/models', +            }, +            ] +            request, metadata = self._interceptor.pre_list_models(request, metadata) +            pb_request = model_service.ListModelsRequest.pb(request) +            transcoded_request = path_template.transcode(http_options, pb_request) + +            uri = transcoded_request['uri'] +            method = transcoded_request['method'] + +            # Jsonify the query params +            query_params = json.loads(json_format.MessageToJson( +                transcoded_request['query_params'], +                including_default_value_fields=False, +                use_integers_for_enums=True, +            )) + +            query_params["$alt"] = "json;enum-encoding=int" + +            # Send the request +            headers = dict(metadata) +            headers['Content-Type'] = 'application/json' +            response = getattr(self._session, method)( +                "{host}{uri}".format(host=self._host, uri=uri), +                timeout=timeout, +                headers=headers, +                params=rest_helpers.flatten_query_params(query_params, strict=True), +                ) + +            # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception +            # subclass. +            if response.status_code >= 400: +                raise core_exceptions.from_http_response(response) + +            # Return the response +            resp = model_service.ListModelsResponse() +            pb_resp = model_service.ListModelsResponse.pb(resp) + +            json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) +            resp = self._interceptor.post_list_models(resp) +            return resp + +    @property +    def get_model(self) -> Callable[ +            [model_service.GetModelRequest], +            model.Model]: +        # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. +        # In C++ this would require a dynamic_cast +        return self._GetModel(self._session, self._host, self._interceptor) # type: ignore + +    @property +    def list_models(self) -> Callable[ +            [model_service.ListModelsRequest], +            model_service.ListModelsResponse]: +        # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here.
+ # In C++ this would require a dynamic_cast + return self._ListModels(self._session, self._host, self._interceptor) # type: ignore + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__=( + 'ModelServiceRestTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/__init__.py new file mode 100644 index 000000000000..f167a9c3175d --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .client import TextServiceClient +from .async_client import TextServiceAsyncClient + +__all__ = ( + 'TextServiceClient', + 'TextServiceAsyncClient', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/async_client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/async_client.py new file mode 100644 index 000000000000..a063956d2782 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/async_client.py @@ -0,0 +1,514 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from collections import OrderedDict +import functools +import re +from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union + +from google.ai.generativelanguage_v1beta2 import gapic_version as package_version + +from google.api_core.client_options import ClientOptions +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials   # type: ignore +from google.oauth2 import service_account              # type: ignore + +try: +    OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError:  # pragma: NO COVER +    OptionalRetry = Union[retries.Retry, object]  # type: ignore + +from google.ai.generativelanguage_v1beta2.types import safety +from google.ai.generativelanguage_v1beta2.types import text_service +from .transports.base import TextServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import TextServiceGrpcAsyncIOTransport +from .client import TextServiceClient + + +class TextServiceAsyncClient: +    """API for using Generative Language Models (GLMs) trained to +    generate text. +    Also known as Large Language Models (LLMs), these generate text +    given an input prompt from the user. +    """ + +    _client: TextServiceClient + +    DEFAULT_ENDPOINT = TextServiceClient.DEFAULT_ENDPOINT +    DEFAULT_MTLS_ENDPOINT = TextServiceClient.DEFAULT_MTLS_ENDPOINT + +    model_path = staticmethod(TextServiceClient.model_path) +    parse_model_path = staticmethod(TextServiceClient.parse_model_path) +    common_billing_account_path = staticmethod(TextServiceClient.common_billing_account_path) +    parse_common_billing_account_path = staticmethod(TextServiceClient.parse_common_billing_account_path) +    common_folder_path = staticmethod(TextServiceClient.common_folder_path) +    parse_common_folder_path = staticmethod(TextServiceClient.parse_common_folder_path) +    common_organization_path = staticmethod(TextServiceClient.common_organization_path) +    parse_common_organization_path = staticmethod(TextServiceClient.parse_common_organization_path) +    common_project_path = staticmethod(TextServiceClient.common_project_path) +    parse_common_project_path = staticmethod(TextServiceClient.parse_common_project_path) +    common_location_path = staticmethod(TextServiceClient.common_location_path) +    parse_common_location_path = staticmethod(TextServiceClient.parse_common_location_path) + +    @classmethod +    def from_service_account_info(cls, info: dict, *args, **kwargs): +        """Creates an instance of this client using the provided credentials +            info. + +        Args: +            info (dict): The service account private key info. +            args: Additional arguments to pass to the constructor. +            kwargs: Additional arguments to pass to the constructor. + +        Returns: +            TextServiceAsyncClient: The constructed client. +        """ +        return TextServiceClient.from_service_account_info.__func__(TextServiceAsyncClient, info, *args, **kwargs)  # type: ignore + +    @classmethod +    def from_service_account_file(cls, filename: str, *args, **kwargs): +        """Creates an instance of this client using the provided credentials +            file. + +        Args: +            filename (str): The path to the service account private key json +                file. +            args: Additional arguments to pass to the constructor. +            kwargs: Additional arguments to pass to the constructor. + +        Returns: +            TextServiceAsyncClient: The constructed client.
+ """ + return TextServiceClient.from_service_account_file.__func__(TextServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return TextServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> TextServiceTransport: + """Returns the transport used by the client instance. + + Returns: + TextServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(TextServiceClient).get_transport_class, type(TextServiceClient)) + + def __init__(self, *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Union[str, TextServiceTransport] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the text service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.TextServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. 
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable +                is "true", then the ``client_cert_source`` property can be used +                to provide client certificate for mutual TLS transport. If +                not provided, the default SSL client certificate will be used if +                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not +                set, no client certificate will be used. + +        Raises: +            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport +                creation failed for any reason. +        """ +        self._client = TextServiceClient( +            credentials=credentials, +            transport=transport, +            client_options=client_options, +            client_info=client_info, + +        ) + +    async def generate_text(self, +            request: Optional[Union[text_service.GenerateTextRequest, dict]] = None, +            *, +            model: Optional[str] = None, +            prompt: Optional[text_service.TextPrompt] = None, +            temperature: Optional[float] = None, +            candidate_count: Optional[int] = None, +            max_output_tokens: Optional[int] = None, +            top_p: Optional[float] = None, +            top_k: Optional[int] = None, +            retry: OptionalRetry = gapic_v1.method.DEFAULT, +            timeout: Union[float, object] = gapic_v1.method.DEFAULT, +            metadata: Sequence[Tuple[str, str]] = (), +            ) -> text_service.GenerateTextResponse: +        r"""Generates a response from the model given an input +        message. + +        .. code-block:: python + +            # This snippet has been automatically generated and should be regarded as a +            # code template only. +            # It will require modifications to work: +            # - It may require correct/in-range values for request initialization. +            # - It may require specifying regional endpoints when creating the service +            #   client as shown in: +            #   https://googleapis.dev/python/google-api-core/latest/client_options.html +            from google.ai import generativelanguage_v1beta2 + +            async def sample_generate_text(): +                # Create a client +                client = generativelanguage_v1beta2.TextServiceAsyncClient() + +                # Initialize request argument(s) +                prompt = generativelanguage_v1beta2.TextPrompt() +                prompt.text = "text_value" + +                request = generativelanguage_v1beta2.GenerateTextRequest( +                    model="model_value", +                    prompt=prompt, +                ) + +                # Make the request +                response = await client.generate_text(request=request) + +                # Handle the response +                print(response) + +        Args: +            request (Optional[Union[google.ai.generativelanguage_v1beta2.types.GenerateTextRequest, dict]]): +                The request object. Request to generate a text completion +                response from the model. +            model (:class:`str`): +                Required. The model name to use with +                the format name=models/{model}. + +                This corresponds to the ``model`` field +                on the ``request`` instance; if ``request`` is provided, this +                should not be set. +            prompt (:class:`google.ai.generativelanguage_v1beta2.types.TextPrompt`): +                Required. The free-form input text +                given to the model as a prompt. +                Given a prompt, the model will generate +                a TextCompletion response it predicts as +                the completion of the input text. + +                This corresponds to the ``prompt`` field +                on the ``request`` instance; if ``request`` is provided, this +                should not be set. +            temperature (:class:`float`): +                Controls the randomness of the output. Note: The default +                value varies by model, see the ``Model.temperature`` +                attribute of the ``Model`` returned by the ``getModel`` +                function. + +                Values can range from [0.0, 1.0], inclusive. A value +                closer to 1.0 will produce responses that are more +                varied and creative, while a value closer to 0.0 will +                typically result in more straightforward responses from +                the model.
+ +                This corresponds to the ``temperature`` field +                on the ``request`` instance; if ``request`` is provided, this +                should not be set. +            candidate_count (:class:`int`): +                Number of generated responses to return. + +                This value must be between [1, 8], inclusive. If unset, +                this will default to 1. + +                This corresponds to the ``candidate_count`` field +                on the ``request`` instance; if ``request`` is provided, this +                should not be set. +            max_output_tokens (:class:`int`): +                The maximum number of tokens to +                include in a candidate. +                If unset, this will default to 64. + +                This corresponds to the ``max_output_tokens`` field +                on the ``request`` instance; if ``request`` is provided, this +                should not be set. +            top_p (:class:`float`): +                The maximum cumulative probability of tokens to consider +                when sampling. + +                The model uses combined Top-k and nucleus sampling. + +                Tokens are sorted based on their assigned probabilities +                so that only the most likely tokens are considered. +                Top-k sampling directly limits the maximum number of +                tokens to consider, while Nucleus sampling limits the number +                of tokens based on the cumulative probability. + +                Note: The default value varies by model, see the +                ``Model.top_p`` attribute of the ``Model`` returned by the +                ``getModel`` function. + +                This corresponds to the ``top_p`` field +                on the ``request`` instance; if ``request`` is provided, this +                should not be set. +            top_k (:class:`int`): +                The maximum number of tokens to consider when sampling. + +                The model uses combined Top-k and nucleus sampling. + +                Top-k sampling considers the set of ``top_k`` most +                probable tokens. Defaults to 40. + +                Note: The default value varies by model, see the +                ``Model.top_k`` attribute of the ``Model`` returned by the +                ``getModel`` function. + +                This corresponds to the ``top_k`` field +                on the ``request`` instance; if ``request`` is provided, this +                should not be set. +            retry (google.api_core.retry.Retry): Designation of what errors, if any, +                should be retried. +            timeout (float): The timeout for this request. +            metadata (Sequence[Tuple[str, str]]): Strings which should be +                sent along with the request as metadata. + +        Returns: +            google.ai.generativelanguage_v1beta2.types.GenerateTextResponse: +                The response from the model, +                including candidate completions. + +        """ +        # Create or coerce a protobuf request object. +        # Quick check: If we got a request object, we should *not* have +        # gotten any keyword arguments that map to the request. +        has_flattened_params = any([model, prompt, temperature, candidate_count, max_output_tokens, top_p, top_k]) +        if request is not None and has_flattened_params: +            raise ValueError("If the `request` argument is set, then none of " +                             "the individual field arguments should be set.") + +        request = text_service.GenerateTextRequest(request) + +        # If we have keyword arguments corresponding to fields on the +        # request, apply these. +        if model is not None: +            request.model = model +        if prompt is not None: +            request.prompt = prompt +        if temperature is not None: +            request.temperature = temperature +        if candidate_count is not None: +            request.candidate_count = candidate_count +        if max_output_tokens is not None: +            request.max_output_tokens = max_output_tokens +        if top_p is not None: +            request.top_p = top_p +        if top_k is not None: +            request.top_k = top_k + +        # Wrap the RPC method; this adds retry and timeout information, +        # and friendly error handling.
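+        # ``wrap_method`` wraps the raw transport callable so that the per-call +        # ``retry`` and ``timeout`` arguments above take effect and the client +        # info is reported in the user-agent metadata.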
+        rpc = gapic_v1.method_async.wrap_method( +            self._client._transport.generate_text, +            default_timeout=None, +            client_info=DEFAULT_CLIENT_INFO, +        ) + +        # Certain fields should be provided within the metadata header; +        # add these here. +        metadata = tuple(metadata) + ( +            gapic_v1.routing_header.to_grpc_metadata(( +                ("model", request.model), +            )), +        ) + +        # Send the request. +        response = await rpc( +            request, +            retry=retry, +            timeout=timeout, +            metadata=metadata, +        ) + +        # Done; return the response. +        return response + +    async def embed_text(self, +            request: Optional[Union[text_service.EmbedTextRequest, dict]] = None, +            *, +            model: Optional[str] = None, +            text: Optional[str] = None, +            retry: OptionalRetry = gapic_v1.method.DEFAULT, +            timeout: Union[float, object] = gapic_v1.method.DEFAULT, +            metadata: Sequence[Tuple[str, str]] = (), +            ) -> text_service.EmbedTextResponse: +        r"""Generates an embedding from the model given an input +        message. + +        .. code-block:: python + +            # This snippet has been automatically generated and should be regarded as a +            # code template only. +            # It will require modifications to work: +            # - It may require correct/in-range values for request initialization. +            # - It may require specifying regional endpoints when creating the service +            #   client as shown in: +            #   https://googleapis.dev/python/google-api-core/latest/client_options.html +            from google.ai import generativelanguage_v1beta2 + +            async def sample_embed_text(): +                # Create a client +                client = generativelanguage_v1beta2.TextServiceAsyncClient() + +                # Initialize request argument(s) +                request = generativelanguage_v1beta2.EmbedTextRequest( +                    model="model_value", +                    text="text_value", +                ) + +                # Make the request +                response = await client.embed_text(request=request) + +                # Handle the response +                print(response) + +        Args: +            request (Optional[Union[google.ai.generativelanguage_v1beta2.types.EmbedTextRequest, dict]]): +                The request object. Request to get a text embedding from +                the model. +            model (:class:`str`): +                Required. The model name to use with +                the format model=models/{model}. + +                This corresponds to the ``model`` field +                on the ``request`` instance; if ``request`` is provided, this +                should not be set. +            text (:class:`str`): +                Required. The free-form input text +                that the model will turn into an +                embedding. + +                This corresponds to the ``text`` field +                on the ``request`` instance; if ``request`` is provided, this +                should not be set. +            retry (google.api_core.retry.Retry): Designation of what errors, if any, +                should be retried. +            timeout (float): The timeout for this request. +            metadata (Sequence[Tuple[str, str]]): Strings which should be +                sent along with the request as metadata. + +        Returns: +            google.ai.generativelanguage_v1beta2.types.EmbedTextResponse: +                The response to an EmbedTextRequest. +        """ +        # Create or coerce a protobuf request object. +        # Quick check: If we got a request object, we should *not* have +        # gotten any keyword arguments that map to the request. +        has_flattened_params = any([model, text]) +        if request is not None and has_flattened_params: +            raise ValueError("If the `request` argument is set, then none of " +                             "the individual field arguments should be set.") + +        request = text_service.EmbedTextRequest(request) + +        # If we have keyword arguments corresponding to fields on the +        # request, apply these. +        if model is not None: +            request.model = model +        if text is not None: +            request.text = text + +        # Wrap the RPC method; this adds retry and timeout information, +        # and friendly error handling.
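+        # As in ``generate_text`` above: wrap the transport callable, then attach +        # the request's ``model`` field as a routing header so the backend can +        # route the call to the right model.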
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.embed_text, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "TextServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +__all__ = ( + "TextServiceAsyncClient", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/client.py new file mode 100644 index 000000000000..39ecd7327b22 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/client.py @@ -0,0 +1,718 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import os +import re +from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast + +from google.ai.generativelanguage_v1beta2 import gapic_version as package_version + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.ai.generativelanguage_v1beta2.types import safety +from google.ai.generativelanguage_v1beta2.types import text_service +from .transports.base import TextServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import TextServiceGrpcTransport +from .transports.grpc_asyncio import TextServiceGrpcAsyncIOTransport +from .transports.rest import TextServiceRestTransport + + +class TextServiceClientMeta(type): + """Metaclass for the TextService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. 
+ """ + _transport_registry = OrderedDict() # type: Dict[str, Type[TextServiceTransport]] + _transport_registry["grpc"] = TextServiceGrpcTransport + _transport_registry["grpc_asyncio"] = TextServiceGrpcAsyncIOTransport + _transport_registry["rest"] = TextServiceRestTransport + + def get_transport_class(cls, + label: Optional[str] = None, + ) -> Type[TextServiceTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class TextServiceClient(metaclass=TextServiceClientMeta): + """API for using Generative Language Models (GLMs) trained to + generate text. + Also known as Large Language Models (LLM)s, these generate text + given an input prompt from the user. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TextServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TextServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> TextServiceTransport: + """Returns the transport used by the client instance. + + Returns: + TextServiceTransport: The transport used by the client + instance. 
+ """ + return self._transport + + @staticmethod + def model_path(model: str,) -> str: + """Returns a fully-qualified model string.""" + return "models/{model}".format(model=model, ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str,str]: + """Parses a model path into its component segments.""" + m = re.match(r"^models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. 
+ + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`") + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`") + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + def __init__(self, *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[Union[str, TextServiceTransport]] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the text service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, TextServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. 
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + client_options = cast(client_options_lib.ClientOptions, client_options) + + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) + + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError("client_options.api_key and credentials are mutually exclusive") + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, TextServiceTransport): + # transport is a TextServiceTransport instance. + if credentials or client_options.credentials_file or api_key_value: + raise ValueError("When providing a transport instance, " + "provide its credentials directly.") + if client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = transport + else: + import google.auth._default # type: ignore + + if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): + credentials = google.auth._default.get_api_key_credentials(api_key_value) + + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=client_options.api_audience, + ) + + def generate_text(self, + request: Optional[Union[text_service.GenerateTextRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[text_service.TextPrompt] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + max_output_tokens: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.GenerateTextResponse: + r"""Generates a response from the model given an input + message. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service
+ # client as shown in:
+ # https://googleapis.dev/python/google-api-core/latest/client_options.html
+ from google.ai import generativelanguage_v1beta2
+
+ def sample_generate_text():
+ # Create a client
+ client = generativelanguage_v1beta2.TextServiceClient()
+
+ # Initialize request argument(s)
+ prompt = generativelanguage_v1beta2.TextPrompt()
+ prompt.text = "text_value"
+
+ request = generativelanguage_v1beta2.GenerateTextRequest(
+ model="model_value",
+ prompt=prompt,
+ )
+
+ # Make the request
+ response = client.generate_text(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.ai.generativelanguage_v1beta2.types.GenerateTextRequest, dict]):
+ The request object. Request to generate a text completion
+ response from the model.
+ model (str):
+ Required. The model name to use with
+ the format name=models/{model}.
+
+ This corresponds to the ``model`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ prompt (google.ai.generativelanguage_v1beta2.types.TextPrompt):
+ Required. The free-form input text
+ given to the model as a prompt.
+ Given a prompt, the model will generate
+ a TextCompletion response it predicts as
+ the completion of the input text.
+
+ This corresponds to the ``prompt`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ temperature (float):
+ Controls the randomness of the output. Note: The default
+ value varies by model, see the ``Model.temperature``
+ attribute of the ``Model`` returned by the ``getModel``
+ function.
+
+ Values can range from [0.0,1.0], inclusive. A value
+ closer to 1.0 will produce responses that are more
+ varied and creative, while a value closer to 0.0 will
+ typically result in more straightforward responses from
+ the model.
+
+ This corresponds to the ``temperature`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ candidate_count (int):
+ Number of generated responses to return.
+
+ This value must be between [1, 8], inclusive. If unset,
+ this will default to 1.
+
+ This corresponds to the ``candidate_count`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ max_output_tokens (int):
+ The maximum number of tokens to
+ include in a candidate.
+ If unset, this will default to 64.
+
+ This corresponds to the ``max_output_tokens`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ top_p (float):
+ The maximum cumulative probability of tokens to consider
+ when sampling.
+
+ The model uses combined Top-k and nucleus sampling.
+
+ Tokens are sorted based on their assigned probabilities
+ so that only the most likely tokens are considered.
+ Top-k sampling directly limits the maximum number of
+ tokens to consider, while Nucleus sampling limits the
+ number of tokens based on the cumulative probability.
+
+ Note: The default value varies by model, see the
+ ``Model.top_p`` attribute of the ``Model`` returned by the
+ ``getModel`` function.
+
+ This corresponds to the ``top_p`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ top_k (int):
+ The maximum number of tokens to consider when sampling.
+
+ The model uses combined Top-k and nucleus sampling.
+
+ Top-k sampling considers the set of ``top_k`` most
+ probable tokens. Defaults to 40.
+
+ Note: The default value varies by model, see the
+ ``Model.top_k`` attribute of the ``Model`` returned by the
+ ``getModel`` function.
+
+ This corresponds to the ``top_k`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.ai.generativelanguage_v1beta2.types.GenerateTextResponse:
+ The response from the model,
+ including candidate completions.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([model, prompt, temperature, candidate_count, max_output_tokens, top_p, top_k])
+ if request is not None and has_flattened_params:
+ raise ValueError('If the `request` argument is set, then none of '
+ 'the individual field arguments should be set.')
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a text_service.GenerateTextRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, text_service.GenerateTextRequest):
+ request = text_service.GenerateTextRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if model is not None:
+ request.model = model
+ if prompt is not None:
+ request.prompt = prompt
+ if temperature is not None:
+ request.temperature = temperature
+ if candidate_count is not None:
+ request.candidate_count = candidate_count
+ if max_output_tokens is not None:
+ request.max_output_tokens = max_output_tokens
+ if top_p is not None:
+ request.top_p = top_p
+ if top_k is not None:
+ request.top_k = top_k
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.generate_text]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((
+ ("model", request.model),
+ )),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def embed_text(self,
+ request: Optional[Union[text_service.EmbedTextRequest, dict]] = None,
+ *,
+ model: Optional[str] = None,
+ text: Optional[str] = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> text_service.EmbedTextResponse:
+ r"""Generates an embedding from the model given an input
+ message.
+
+ .. code-block:: python
+
+ # This snippet has been automatically generated and should be regarded as a
+ # code template only.
+ # It will require modifications to work:
+ # - It may require correct/in-range values for request initialization.
+ # - It may require specifying regional endpoints when creating the service
+ # client as shown in:
+ # https://googleapis.dev/python/google-api-core/latest/client_options.html
+ from google.ai import generativelanguage_v1beta2
+
+ def sample_embed_text():
+ # Create a client
+ client = generativelanguage_v1beta2.TextServiceClient()
+
+ # Initialize request argument(s)
+ request = generativelanguage_v1beta2.EmbedTextRequest(
+ model="model_value",
+ text="text_value",
+ )
+
+ # Make the request
+ response = client.embed_text(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Union[google.ai.generativelanguage_v1beta2.types.EmbedTextRequest, dict]):
+ The request object. Request to get a text embedding from
+ the model.
+ model (str):
+ Required. The model name to use with
+ the format model=models/{model}.
+
+ This corresponds to the ``model`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ text (str):
+ Required. The free-form input text
+ that the model will turn into an
+ embedding.
+
+ This corresponds to the ``text`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.ai.generativelanguage_v1beta2.types.EmbedTextResponse:
+ The response to an EmbedTextRequest.
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([model, text])
+ if request is not None and has_flattened_params:
+ raise ValueError('If the `request` argument is set, then none of '
+ 'the individual field arguments should be set.')
+
+ # Minor optimization to avoid making a copy if the user passes
+ # in a text_service.EmbedTextRequest.
+ # There's no risk of modifying the input as we've already verified
+ # there are no flattened fields.
+ if not isinstance(request, text_service.EmbedTextRequest):
+ request = text_service.EmbedTextRequest(request)
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if model is not None:
+ request.model = model
+ if text is not None:
+ request.text = text
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = self._transport._wrapped_methods[self._transport.embed_text]
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((
+ ("model", request.model),
+ )),
+ )
+
+ # Send the request.
+ response = rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ def __enter__(self) -> "TextServiceClient":
+ return self
+
+ def __exit__(self, type, value, traceback):
+ """Releases underlying transport's resources.
+
+ .. warning::
+ ONLY use as a context manager if the transport is NOT shared
+ with other clients! Exiting the with block will CLOSE the transport
+ and may cause errors in other clients!
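+
+ A typical context-manager usage (an illustrative sketch; ``request``
+ is assumed to be built as in the method samples above):
+
+ .. code-block:: python
+
+ with TextServiceClient() as client:
+ response = client.generate_text(request=request)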
+ """ + self.transport.close() + + + + + + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +__all__ = ( + "TextServiceClient", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/__init__.py new file mode 100644 index 000000000000..71e949c7a4f5 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/__init__.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import TextServiceTransport +from .grpc import TextServiceGrpcTransport +from .grpc_asyncio import TextServiceGrpcAsyncIOTransport +from .rest import TextServiceRestTransport +from .rest import TextServiceRestInterceptor + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[TextServiceTransport]] +_transport_registry['grpc'] = TextServiceGrpcTransport +_transport_registry['grpc_asyncio'] = TextServiceGrpcAsyncIOTransport +_transport_registry['rest'] = TextServiceRestTransport + +__all__ = ( + 'TextServiceTransport', + 'TextServiceGrpcTransport', + 'TextServiceGrpcAsyncIOTransport', + 'TextServiceRestTransport', + 'TextServiceRestInterceptor', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/base.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/base.py new file mode 100644 index 000000000000..b038dec99299 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/base.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+import abc
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Union
+
+from google.ai.generativelanguage_v1beta2 import gapic_version as package_version
+
+import google.auth # type: ignore
+import google.api_core
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
+from google.auth import credentials as ga_credentials # type: ignore
+from google.oauth2 import service_account # type: ignore
+
+from google.ai.generativelanguage_v1beta2.types import text_service
+
+DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__)
+
+
+class TextServiceTransport(abc.ABC):
+ """Abstract transport class for TextService."""
+
+ AUTH_SCOPES = (
+ )
+
+ DEFAULT_HOST: str = 'generativelanguage.googleapis.com'
+ def __init__(
+ self, *,
+ host: str = DEFAULT_HOST,
+ credentials: Optional[ga_credentials.Credentials] = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ quota_project_id: Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
+ api_audience: Optional[str] = None,
+ **kwargs,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is mutually exclusive with credentials.
+ scopes (Optional[Sequence[str]]): A list of scopes.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
+ """
+
+ scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES}
+
+ # Save the scopes.
+ self._scopes = scopes
+
+ # If no credentials are provided, then determine the appropriate
+ # defaults.
+ if credentials and credentials_file:
+ raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive")
+
+ if credentials_file is not None:
+ credentials, _ = google.auth.load_credentials_from_file(
+ credentials_file,
+ **scopes_kwargs,
+ quota_project_id=quota_project_id
+ )
+ elif credentials is None:
+ credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id)
+ # Don't apply the audience if the credentials file was passed by the user.
+ if hasattr(credentials, "with_gdch_audience"):
+ credentials = credentials.with_gdch_audience(api_audience if api_audience else host)
+
+ # If the credentials are service account credentials, then always try to use self signed JWT.
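+ # (A self-signed JWT lets service account credentials skip the OAuth
+ # token exchange entirely, which is why the check below is limited to
+ # service_account.Credentials.)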
+ if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"):
+ credentials = credentials.with_always_use_jwt_access(True)
+
+ # Save the credentials.
+ self._credentials = credentials
+
+ # Save the hostname. Default to port 443 (HTTPS) if none is specified.
+ if ':' not in host:
+ host += ':443'
+ self._host = host
+
+ def _prep_wrapped_messages(self, client_info):
+ # Precompute the wrapped methods.
+ self._wrapped_methods = {
+ self.generate_text: gapic_v1.method.wrap_method(
+ self.generate_text,
+ default_timeout=None,
+ client_info=client_info,
+ ),
+ self.embed_text: gapic_v1.method.wrap_method(
+ self.embed_text,
+ default_timeout=None,
+ client_info=client_info,
+ ),
+ }
+
+ def close(self):
+ """Closes resources associated with the transport.
+
+ .. warning::
+ Only call this method if the transport is NOT shared
+ with other clients - this may cause errors in other clients!
+ """
+ raise NotImplementedError()
+
+ @property
+ def generate_text(self) -> Callable[
+ [text_service.GenerateTextRequest],
+ Union[
+ text_service.GenerateTextResponse,
+ Awaitable[text_service.GenerateTextResponse]
+ ]]:
+ raise NotImplementedError()
+
+ @property
+ def embed_text(self) -> Callable[
+ [text_service.EmbedTextRequest],
+ Union[
+ text_service.EmbedTextResponse,
+ Awaitable[text_service.EmbedTextResponse]
+ ]]:
+ raise NotImplementedError()
+
+ @property
+ def kind(self) -> str:
+ raise NotImplementedError()
+
+
+__all__ = (
+ 'TextServiceTransport',
+)
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc.py
new file mode 100644
index 000000000000..4835582937e6
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc.py
@@ -0,0 +1,295 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import warnings
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+
+from google.api_core import grpc_helpers
+from google.api_core import gapic_v1
+import google.auth # type: ignore
+from google.auth import credentials as ga_credentials # type: ignore
+from google.auth.transport.grpc import SslCredentials # type: ignore
+
+import grpc # type: ignore
+
+from google.ai.generativelanguage_v1beta2.types import text_service
+from .base import TextServiceTransport, DEFAULT_CLIENT_INFO
+
+
+class TextServiceGrpcTransport(TextServiceTransport):
+ """gRPC backend transport for TextService.
+
+ API for using Generative Language Models (GLMs) trained to
+ generate text.
+ Also known as Large Language Models (LLMs), these generate text
+ given an input prompt from the user.
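+
+ A minimal construction sketch (illustrative only; assumes application
+ default credentials are available in the environment):
+
+ .. code-block:: python
+
+ transport = TextServiceGrpcTransport()
+ client = TextServiceClient(transport=transport)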
+
+ This class defines the same methods as the primary client, so the
+ primary client can load the underlying transport implementation
+ and call it.
+
+ It sends protocol buffers over the wire using gRPC (which is built on
+ top of HTTP/2); the ``grpcio`` package must be installed.
+ """
+ _stubs: Dict[str, Callable]
+
+ def __init__(self, *,
+ host: str = 'generativelanguage.googleapis.com',
+ credentials: Optional[ga_credentials.Credentials] = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ channel: Optional[grpc.Channel] = None,
+ api_mtls_endpoint: Optional[str] = None,
+ client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+ ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None,
+ client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+ quota_project_id: Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
+ api_audience: Optional[str] = None,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ This argument is ignored if ``channel`` is provided.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+ ignored if ``channel`` is provided.
+ channel (Optional[grpc.Channel]): A ``Channel`` instance through
+ which to make calls.
+ api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+ If provided, it overrides the ``host`` argument and tries to create
+ a mutual TLS channel with client SSL credentials from
+ ``client_cert_source`` or application default SSL credentials.
+ client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ Deprecated. A callback to provide client SSL certificate bytes and
+ private key bytes, both in PEM format. It is ignored if
+ ``api_mtls_endpoint`` is None.
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+ for the grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure a mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+ creation failed for any reason.
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+ and ``credentials_file`` are passed.
+ """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel(cls, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
+ """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def generate_text(self) -> Callable[ + [text_service.GenerateTextRequest], + text_service.GenerateTextResponse]: + r"""Return a callable for the generate text method over gRPC. + + Generates a response from the model given an input + message. + + Returns: + Callable[[~.GenerateTextRequest], + ~.GenerateTextResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'generate_text' not in self._stubs: + self._stubs['generate_text'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta2.TextService/GenerateText', + request_serializer=text_service.GenerateTextRequest.serialize, + response_deserializer=text_service.GenerateTextResponse.deserialize, + ) + return self._stubs['generate_text'] + + @property + def embed_text(self) -> Callable[ + [text_service.EmbedTextRequest], + text_service.EmbedTextResponse]: + r"""Return a callable for the embed text method over gRPC. + + Generates an embedding from the model given an input + message. + + Returns: + Callable[[~.EmbedTextRequest], + ~.EmbedTextResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'embed_text' not in self._stubs: + self._stubs['embed_text'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta2.TextService/EmbedText', + request_serializer=text_service.EmbedTextRequest.serialize, + response_deserializer=text_service.EmbedTextResponse.deserialize, + ) + return self._stubs['embed_text'] + + def close(self): + self.grpc_channel.close() + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ( + 'TextServiceGrpcTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc_asyncio.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..8a8cdeeda949 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc_asyncio.py @@ -0,0 +1,294 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union
+
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers_async
+from google.auth import credentials as ga_credentials # type: ignore
+from google.auth.transport.grpc import SslCredentials # type: ignore
+
+import grpc # type: ignore
+from grpc.experimental import aio # type: ignore
+
+from google.ai.generativelanguage_v1beta2.types import text_service
+from .base import TextServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import TextServiceGrpcTransport
+
+
+class TextServiceGrpcAsyncIOTransport(TextServiceTransport):
+ """gRPC AsyncIO backend transport for TextService.
+
+ API for using Generative Language Models (GLMs) trained to
+ generate text.
+ Also known as Large Language Models (LLMs), these generate text
+ given an input prompt from the user.
+
+ This class defines the same methods as the primary client, so the
+ primary client can load the underlying transport implementation
+ and call it.
+
+ It sends protocol buffers over the wire using gRPC (which is built on
+ top of HTTP/2); the ``grpcio`` package must be installed.
+ """
+
+ _grpc_channel: aio.Channel
+ _stubs: Dict[str, Callable] = {}
+
+ @classmethod
+ def create_channel(cls,
+ host: str = 'generativelanguage.googleapis.com',
+ credentials: Optional[ga_credentials.Credentials] = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ quota_project_id: Optional[str] = None,
+ **kwargs) -> aio.Channel:
+ """Create and return a gRPC AsyncIO channel object.
+ Args:
+ host (Optional[str]): The host for the channel to use.
+ credentials (Optional[~.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify this application to the service. If
+ none are specified, the client will attempt to ascertain
+ the credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is mutually exclusive with credentials.
+ scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ kwargs (Optional[dict]): Keyword arguments, which are passed to the
+ channel creation.
+ Returns:
+ aio.Channel: A gRPC AsyncIO channel object.
+ """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs + ) + + def __init__(self, *, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[aio.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
+ """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def generate_text(self) -> Callable[ + [text_service.GenerateTextRequest], + Awaitable[text_service.GenerateTextResponse]]: + r"""Return a callable for the generate text method over gRPC. + + Generates a response from the model given an input + message. + + Returns: + Callable[[~.GenerateTextRequest], + Awaitable[~.GenerateTextResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'generate_text' not in self._stubs: + self._stubs['generate_text'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta2.TextService/GenerateText', + request_serializer=text_service.GenerateTextRequest.serialize, + response_deserializer=text_service.GenerateTextResponse.deserialize, + ) + return self._stubs['generate_text'] + + @property + def embed_text(self) -> Callable[ + [text_service.EmbedTextRequest], + Awaitable[text_service.EmbedTextResponse]]: + r"""Return a callable for the embed text method over gRPC. 
+ + Generates an embedding from the model given an input + message. + + Returns: + Callable[[~.EmbedTextRequest], + Awaitable[~.EmbedTextResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'embed_text' not in self._stubs: + self._stubs['embed_text'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta2.TextService/EmbedText', + request_serializer=text_service.EmbedTextRequest.serialize, + response_deserializer=text_service.EmbedTextResponse.deserialize, + ) + return self._stubs['embed_text'] + + def close(self): + return self.grpc_channel.close() + + +__all__ = ( + 'TextServiceGrpcAsyncIOTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/rest.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/rest.py new file mode 100644 index 000000000000..2480e0dd0389 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/rest.py @@ -0,0 +1,423 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.auth.transport.requests import AuthorizedSession # type: ignore +import json # type: ignore +import grpc # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core import rest_helpers +from google.api_core import rest_streaming +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from requests import __version__ as requests_version +import dataclasses +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + + +from google.ai.generativelanguage_v1beta2.types import text_service + +from .base import TextServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class TextServiceRestInterceptor: + """Interceptor for TextService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. 
+ Example use cases include:
+ * Logging
+ * Verifying requests according to service or custom semantics
+ * Stripping extraneous information from responses
+
+ These use cases and more can be enabled by injecting an
+ instance of a custom subclass when constructing the TextServiceRestTransport.
+
+ .. code-block:: python
+ class MyCustomTextServiceInterceptor(TextServiceRestInterceptor):
+ def pre_embed_text(self, request, metadata):
+ logging.log(f"Received request: {request}")
+ return request, metadata
+
+ def post_embed_text(self, response):
+ logging.log(f"Received response: {response}")
+ return response
+
+ def pre_generate_text(self, request, metadata):
+ logging.log(f"Received request: {request}")
+ return request, metadata
+
+ def post_generate_text(self, response):
+ logging.log(f"Received response: {response}")
+ return response
+
+ transport = TextServiceRestTransport(interceptor=MyCustomTextServiceInterceptor())
+ client = TextServiceClient(transport=transport)
+
+
+ """
+ def pre_embed_text(self, request: text_service.EmbedTextRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[text_service.EmbedTextRequest, Sequence[Tuple[str, str]]]:
+ """Pre-rpc interceptor for embed_text
+
+ Override in a subclass to manipulate the request or metadata
+ before they are sent to the TextService server.
+ """
+ return request, metadata
+
+ def post_embed_text(self, response: text_service.EmbedTextResponse) -> text_service.EmbedTextResponse:
+ """Post-rpc interceptor for embed_text
+
+ Override in a subclass to manipulate the response
+ after it is returned by the TextService server but before
+ it is returned to user code.
+ """
+ return response
+
+ def pre_generate_text(self, request: text_service.GenerateTextRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[text_service.GenerateTextRequest, Sequence[Tuple[str, str]]]:
+ """Pre-rpc interceptor for generate_text
+
+ Override in a subclass to manipulate the request or metadata
+ before they are sent to the TextService server.
+ """
+ return request, metadata
+
+ def post_generate_text(self, response: text_service.GenerateTextResponse) -> text_service.GenerateTextResponse:
+ """Post-rpc interceptor for generate_text
+
+ Override in a subclass to manipulate the response
+ after it is returned by the TextService server but before
+ it is returned to user code.
+ """
+ return response
+
+
+@dataclasses.dataclass
+class TextServiceRestStub:
+ _session: AuthorizedSession
+ _host: str
+ _interceptor: TextServiceRestInterceptor
+
+
+class TextServiceRestTransport(TextServiceTransport):
+ """REST backend transport for TextService.
+
+ API for using Generative Language Models (GLMs) trained to
+ generate text.
+ Also known as Large Language Models (LLMs), these generate text
+ given an input prompt from the user.
+
+ This class defines the same methods as the primary client, so the
+ primary client can load the underlying transport implementation
+ and call it.
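+
+ A minimal selection sketch (illustrative only; the ``"rest"`` string is
+ resolved through the transport registry in ``transports/__init__.py``):
+
+ .. code-block:: python
+
+ client = TextServiceClient(transport="rest")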
+
+ It sends JSON representations of protocol buffers over HTTP/1.1
+
+ """
+
+ def __init__(self, *,
+ host: str = 'generativelanguage.googleapis.com',
+ credentials: Optional[ga_credentials.Credentials] = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ client_cert_source_for_mtls: Optional[Callable[[
+ ], Tuple[bytes, bytes]]] = None,
+ quota_project_id: Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
+ url_scheme: str = 'https',
+ interceptor: Optional[TextServiceRestInterceptor] = None,
+ api_audience: Optional[str] = None,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+ ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client
+ certificate to configure mutual TLS HTTP channel. It is ignored
+ if ``channel`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you are developing
+ your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
+ url_scheme: the protocol scheme for the API endpoint. Normally
+ "https", but for testing or local servers,
+ "http" can be specified.
+ """
+ # Run the base constructor
+ # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc.
+ # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
+ # credentials object
+ maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host)
+ if maybe_url_match is None:
+ raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER
+
+ url_match_items = maybe_url_match.groupdict()
+
+ host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host
+
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
+ api_audience=api_audience
+ )
+ self._session = AuthorizedSession(
+ self._credentials, default_host=self.DEFAULT_HOST)
+ if client_cert_source_for_mtls:
+ self._session.configure_mtls_channel(client_cert_source_for_mtls)
+ self._interceptor = interceptor or TextServiceRestInterceptor()
+ self._prep_wrapped_messages(client_info)
+
+ class _EmbedText(TextServiceRestStub):
+ def __hash__(self):
+ return hash("EmbedText")
+
+ __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {
+ }
+
+ @classmethod
+ def _get_unset_required_fields(cls, message_dict):
+ return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict}
+
+ def __call__(self,
+ request: text_service.EmbedTextRequest, *,
+ retry: OptionalRetry=gapic_v1.method.DEFAULT,
+ timeout: Optional[float]=None,
+ metadata: Sequence[Tuple[str, str]]=(),
+ ) -> text_service.EmbedTextResponse:
+ r"""Call the embed text method over HTTP.
+
+ Args:
+ request (~.text_service.EmbedTextRequest):
+ The request object. Request to get a text embedding from
+ the model.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ ~.text_service.EmbedTextResponse:
+ The response to an EmbedTextRequest.
+ """
+
+ http_options: List[Dict[str, str]] = [{
+ 'method': 'post',
+ 'uri': '/v1beta2/{model=models/*}:embedText',
+ 'body': '*',
+ },
+ ]
+ request, metadata = self._interceptor.pre_embed_text(request, metadata)
+ pb_request = text_service.EmbedTextRequest.pb(request)
+ transcoded_request = path_template.transcode(http_options, pb_request)
+
+ # Jsonify the request body
+
+ body = json_format.MessageToJson(
+ transcoded_request['body'],
+ including_default_value_fields=False,
+ use_integers_for_enums=True
+ )
+ uri = transcoded_request['uri']
+ method = transcoded_request['method']
+
+ # Jsonify the query params
+ query_params = json.loads(json_format.MessageToJson(
+ transcoded_request['query_params'],
+ including_default_value_fields=False,
+ use_integers_for_enums=True,
+ ))
+ query_params.update(self._get_unset_required_fields(query_params))
+
+ query_params["$alt"] = "json;enum-encoding=int"
+
+ # Send the request
+ headers = dict(metadata)
+ headers['Content-Type'] = 'application/json'
+ response = getattr(self._session, method)(
+ "{host}{uri}".format(host=self._host, uri=uri),
+ timeout=timeout,
+ headers=headers,
+ params=rest_helpers.flatten_query_params(query_params, strict=True),
+ data=body,
+ )
+
+ # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception
+ # subclass.
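+ # (from_http_response maps the HTTP status code to the matching
+ # google.api_core exception type, e.g. 404 -> NotFound.)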
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = text_service.EmbedTextResponse() + pb_resp = text_service.EmbedTextResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_embed_text(resp) + return resp + + class _GenerateText(TextServiceRestStub): + def __hash__(self): + return hash("GenerateText") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: text_service.GenerateTextRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> text_service.GenerateTextResponse: + r"""Call the generate text method over HTTP. + + Args: + request (~.text_service.GenerateTextRequest): + The request object. Request to generate a text completion + response from the model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.text_service.GenerateTextResponse: + The response from the model, + including candidate completions. + + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'post', + 'uri': '/v1beta2/{model=models/*}:generateText', + 'body': '*', + }, + ] + request, metadata = self._interceptor.pre_generate_text(request, metadata) + pb_request = text_service.GenerateTextRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + including_default_value_fields=False, + use_integers_for_enums=True + ) + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = text_service.GenerateTextResponse() + pb_resp = text_service.GenerateTextResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_generate_text(resp) + return resp + + @property + def embed_text(self) -> Callable[ + [text_service.EmbedTextRequest], + text_service.EmbedTextResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._EmbedText(self._session, self._host, self._interceptor) # type: ignore + + @property + def generate_text(self) -> Callable[ + [text_service.GenerateTextRequest], + text_service.GenerateTextResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GenerateText(self._session, self._host, self._interceptor) # type: ignore + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__=( + 'TextServiceRestTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/__init__.py new file mode 100644 index 000000000000..6f8563368f76 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/__init__.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .citation import ( + CitationMetadata, + CitationSource, +) +from .discuss_service import ( + CountMessageTokensRequest, + CountMessageTokensResponse, + Example, + GenerateMessageRequest, + GenerateMessageResponse, + Message, + MessagePrompt, +) +from .model import ( + Model, +) +from .model_service import ( + GetModelRequest, + ListModelsRequest, + ListModelsResponse, +) +from .safety import ( + ContentFilter, + SafetyFeedback, + SafetyRating, + SafetySetting, + HarmCategory, +) +from .text_service import ( + Embedding, + EmbedTextRequest, + EmbedTextResponse, + GenerateTextRequest, + GenerateTextResponse, + TextCompletion, + TextPrompt, +) + +__all__ = ( + 'CitationMetadata', + 'CitationSource', + 'CountMessageTokensRequest', + 'CountMessageTokensResponse', + 'Example', + 'GenerateMessageRequest', + 'GenerateMessageResponse', + 'Message', + 'MessagePrompt', + 'Model', + 'GetModelRequest', + 'ListModelsRequest', + 'ListModelsResponse', + 'ContentFilter', + 'SafetyFeedback', + 'SafetyRating', + 'SafetySetting', + 'HarmCategory', + 'Embedding', + 'EmbedTextRequest', + 'EmbedTextResponse', + 'GenerateTextRequest', + 'GenerateTextResponse', + 'TextCompletion', + 'TextPrompt', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/citation.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/citation.py new file mode 100644 index 000000000000..e4ecf054b568 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/citation.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
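For orientation, a minimal usage sketch of the REST transport defined above. It assumes the generated `TextServiceClient` surface and Application Default Credentials; the model name is illustrative, not prescribed by this patch.

    # Editorial sketch (not part of the generated patch): route calls through
    # the REST transport above instead of the default gRPC transport.
    from google.ai import generativelanguage_v1beta2 as glm

    client = glm.TextServiceClient(transport="rest")

    response = client.embed_text(
        request=glm.EmbedTextRequest(
            model="models/embedding-gecko-001",  # illustrative model name
            text="Hello, world",
        )
    )
    print(response.embedding.value)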
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + + +__protobuf__ = proto.module( + package='google.ai.generativelanguage.v1beta2', + manifest={ + 'CitationMetadata', + 'CitationSource', + }, +) + + +class CitationMetadata(proto.Message): + r"""A collection of source attributions for a piece of content. + + Attributes: + citation_sources (MutableSequence[google.ai.generativelanguage_v1beta2.types.CitationSource]): + Citations to sources for a specific response. + """ + + citation_sources: MutableSequence['CitationSource'] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message='CitationSource', + ) + + +class CitationSource(proto.Message): + r"""A citation to a source for a portion of a specific response. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + start_index (int): + Optional. Start of segment of the response + that is attributed to this source. + + Index indicates the start of the segment, + measured in bytes. + + This field is a member of `oneof`_ ``_start_index``. + end_index (int): + Optional. End of the attributed segment, + exclusive. + + This field is a member of `oneof`_ ``_end_index``. + uri (str): + Optional. URI that is attributed as a source + for a portion of the text. + + This field is a member of `oneof`_ ``_uri``. + license_ (str): + Optional. License for the GitHub project that + is attributed as a source for segment. + + License info is required for code citations. + + This field is a member of `oneof`_ ``_license``. + """ + + start_index: int = proto.Field( + proto.INT32, + number=1, + optional=True, + ) + end_index: int = proto.Field( + proto.INT32, + number=2, + optional=True, + ) + uri: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + license_: str = proto.Field( + proto.STRING, + number=4, + optional=True, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/discuss_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/discuss_service.py new file mode 100644 index 000000000000..f91ed5b98bed --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/discuss_service.py @@ -0,0 +1,358 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
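An editorial note on the citation types just shown: the field is spelled `license_` (trailing underscore) because the proto field `license` would shadow the Python builtin of the same name, so the generator renames it. A minimal hedged sketch:

    # Sketch only: constructing the citation types from citation.py above.
    from google.ai import generativelanguage_v1beta2 as glm

    source = glm.CitationSource(
        start_index=0,
        end_index=42,
        uri="https://example.com/source",  # illustrative URI
        license_="MIT",                    # note the trailing underscore
    )
    metadata = glm.CitationMetadata(citation_sources=[source])
    print(metadata.citation_sources[0].license_)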
+# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.ai.generativelanguage_v1beta2.types import citation +from google.ai.generativelanguage_v1beta2.types import safety + + +__protobuf__ = proto.module( + package='google.ai.generativelanguage.v1beta2', + manifest={ + 'GenerateMessageRequest', + 'GenerateMessageResponse', + 'Message', + 'MessagePrompt', + 'Example', + 'CountMessageTokensRequest', + 'CountMessageTokensResponse', + }, +) + + +class GenerateMessageRequest(proto.Message): + r"""Request to generate a message response from the model. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + model (str): + Required. The name of the model to use. + + Format: ``name=models/{model}``. + prompt (google.ai.generativelanguage_v1beta2.types.MessagePrompt): + Required. The structured textual input given + to the model as a prompt. + Given a + prompt, the model will return what it predicts + is the next message in the discussion. + temperature (float): + Optional. Controls the randomness of the output. + + Values can range over ``[0.0,1.0]``, inclusive. A value + closer to ``1.0`` will produce responses that are more + varied, while a value closer to ``0.0`` will typically + result in less surprising responses from the model. + + This field is a member of `oneof`_ ``_temperature``. + candidate_count (int): + Optional. The number of generated response messages to + return. + + This value must be between ``[1, 8]``, inclusive. If unset, + this will default to ``1``. + + This field is a member of `oneof`_ ``_candidate_count``. + top_p (float): + Optional. The maximum cumulative probability of tokens to + consider when sampling. + + The model uses combined Top-k and nucleus sampling. + + Nucleus sampling considers the smallest set of tokens whose + probability sum is at least ``top_p``. + + This field is a member of `oneof`_ ``_top_p``. + top_k (int): + Optional. The maximum number of tokens to consider when + sampling. + + The model uses combined Top-k and nucleus sampling. + + Top-k sampling considers the set of ``top_k`` most probable + tokens. + + This field is a member of `oneof`_ ``_top_k``. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + prompt: 'MessagePrompt' = proto.Field( + proto.MESSAGE, + number=2, + message='MessagePrompt', + ) + temperature: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + candidate_count: int = proto.Field( + proto.INT32, + number=4, + optional=True, + ) + top_p: float = proto.Field( + proto.FLOAT, + number=5, + optional=True, + ) + top_k: int = proto.Field( + proto.INT32, + number=6, + optional=True, + ) + + +class GenerateMessageResponse(proto.Message): + r"""The response from the model. + + This includes candidate messages and + conversation history in the form of chronologically-ordered + messages. + + Attributes: + candidates (MutableSequence[google.ai.generativelanguage_v1beta2.types.Message]): + Candidate response messages from the model. + messages (MutableSequence[google.ai.generativelanguage_v1beta2.types.Message]): + The conversation history used by the model. + filters (MutableSequence[google.ai.generativelanguage_v1beta2.types.ContentFilter]): + A set of content filtering metadata for the prompt and + response text. 
+ + This indicates which ``SafetyCategory``\ (s) blocked a + candidate from this response, the lowest ``HarmProbability`` + that triggered a block, and the HarmThreshold setting for + that category. + """ + + candidates: MutableSequence['Message'] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message='Message', + ) + messages: MutableSequence['Message'] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message='Message', + ) + filters: MutableSequence[safety.ContentFilter] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message=safety.ContentFilter, + ) + + +class Message(proto.Message): + r"""The base unit of structured text. + + A ``Message`` includes an ``author`` and the ``content`` of the + ``Message``. + + The ``author`` is used to tag messages when they are fed to the + model as text. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + author (str): + Optional. The author of this Message. + + This serves as a key for tagging + the content of this Message when it is fed to + the model as text. + + The author can be any alphanumeric string. + content (str): + Required. The text content of the structured ``Message``. + citation_metadata (google.ai.generativelanguage_v1beta2.types.CitationMetadata): + Output only. Citation information for model-generated + ``content`` in this ``Message``. + + If this ``Message`` was generated as output from the model, + this field may be populated with attribution information for + any text included in the ``content``. This field is used + only on output. + + This field is a member of `oneof`_ ``_citation_metadata``. + """ + + author: str = proto.Field( + proto.STRING, + number=1, + ) + content: str = proto.Field( + proto.STRING, + number=2, + ) + citation_metadata: citation.CitationMetadata = proto.Field( + proto.MESSAGE, + number=3, + optional=True, + message=citation.CitationMetadata, + ) + + +class MessagePrompt(proto.Message): + r"""All of the structured input text passed to the model as a prompt. + + A ``MessagePrompt`` contains a structured set of fields that provide + context for the conversation, examples of user input/model output + message pairs that prime the model to respond in different ways, and + the conversation history or list of messages representing the + alternating turns of the conversation between the user and the + model. + + Attributes: + context (str): + Optional. Text that should be provided to the model first to + ground the response. + + If not empty, this ``context`` will be given to the model + first before the ``examples`` and ``messages``. When using a + ``context`` be sure to provide it with every request to + maintain continuity. + + This field can be a description of your prompt to the model + to help provide context and guide the responses. Examples: + "Translate the phrase from English to French." or "Given a + statement, classify the sentiment as happy, sad or neutral." + + Anything included in this field will take precedence over + message history if the total input size exceeds the model's + ``input_token_limit`` and the input request is truncated. + examples (MutableSequence[google.ai.generativelanguage_v1beta2.types.Example]): + Optional. Examples of what the model should generate. + + This includes both user input and the response that the + model should emulate. 
+ + These ``examples`` are treated identically to conversation + messages except that they take precedence over the history + in ``messages``: If the total input size exceeds the model's + ``input_token_limit`` the input will be truncated. Items + will be dropped from ``messages`` before ``examples``. + messages (MutableSequence[google.ai.generativelanguage_v1beta2.types.Message]): + Required. A snapshot of the recent conversation history + sorted chronologically. + + Turns alternate between two authors. + + If the total input size exceeds the model's + ``input_token_limit`` the input will be truncated: The + oldest items will be dropped from ``messages``. + """ + + context: str = proto.Field( + proto.STRING, + number=1, + ) + examples: MutableSequence['Example'] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message='Example', + ) + messages: MutableSequence['Message'] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message='Message', + ) + + +class Example(proto.Message): + r"""An input/output example used to instruct the Model. + + It demonstrates how the model should respond or format its + response. + + Attributes: + input (google.ai.generativelanguage_v1beta2.types.Message): + Required. An example of an input ``Message`` from the user. + output (google.ai.generativelanguage_v1beta2.types.Message): + Required. An example of what the model should + output given the input. + """ + + input: 'Message' = proto.Field( + proto.MESSAGE, + number=1, + message='Message', + ) + output: 'Message' = proto.Field( + proto.MESSAGE, + number=2, + message='Message', + ) + + +class CountMessageTokensRequest(proto.Message): + r"""Counts the number of tokens in the ``prompt`` sent to a model. + + Models may tokenize text differently, so each model may return a + different ``token_count``. + + Attributes: + model (str): + Required. The model's resource name. This serves as an ID + for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + prompt (google.ai.generativelanguage_v1beta2.types.MessagePrompt): + Required. The prompt, whose token count is to + be returned. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + prompt: 'MessagePrompt' = proto.Field( + proto.MESSAGE, + number=2, + message='MessagePrompt', + ) + + +class CountMessageTokensResponse(proto.Message): + r"""A response from ``CountMessageTokens``. + + It returns the model's ``token_count`` for the ``prompt``. + + Attributes: + token_count (int): + The number of tokens that the ``model`` tokenizes the + ``prompt`` into. + + Always non-negative. + """ + + token_count: int = proto.Field( + proto.INT32, + number=1, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model.py new file mode 100644 index 000000000000..d1698c736311 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
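With the discuss-service types now fully defined above, a short hedged sketch of how they compose; names and values are illustrative:

    # Sketch only: assembling a MessagePrompt per the docstrings above.
    # Truncation order follows the documentation: `context` is kept, then
    # `examples`; the oldest entries in `messages` are dropped first.
    from google.ai import generativelanguage_v1beta2 as glm

    prompt = glm.MessagePrompt(
        context="Answer like a pirate.",
        examples=[
            glm.Example(
                input=glm.Message(content="Hello"),
                output=glm.Message(content="Ahoy, matey!"),
            )
        ],
        messages=[glm.Message(author="user", content="How are you?")],
    )

    request = glm.GenerateMessageRequest(
        model="models/chat-bison-001",  # example name from the Model docstring
        prompt=prompt,
        temperature=0.5,
        candidate_count=2,
    )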
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + + +__protobuf__ = proto.module( + package='google.ai.generativelanguage.v1beta2', + manifest={ + 'Model', + }, +) + + +class Model(proto.Message): + r"""Information about a Generative Language Model. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + name (str): + Required. The resource name of the ``Model``. + + Format: ``models/{model}`` with a ``{model}`` naming + convention of: + + - "{base_model_id}-{version}" + + Examples: + + - ``models/chat-bison-001`` + base_model_id (str): + Required. The name of the base model, pass this to the + generation request. + + Examples: + + - ``chat-bison`` + version (str): + Required. The version number of the model. + + This represents the major version + display_name (str): + The human-readable name of the model. E.g. + "Chat Bison". + The name can be up to 128 characters long and + can consist of any UTF-8 characters. + description (str): + A short description of the model. + input_token_limit (int): + Maximum number of input tokens allowed for + this model. + output_token_limit (int): + Maximum number of output tokens available for + this model. + supported_generation_methods (MutableSequence[str]): + The model's supported generation methods. + + The method names are defined as Pascal case strings, such as + ``generateMessage`` which correspond to API methods. + temperature (float): + Controls the randomness of the output. + + Values can range over ``[0.0,1.0]``, inclusive. A value + closer to ``1.0`` will produce responses that are more + varied, while a value closer to ``0.0`` will typically + result in less surprising responses from the model. This + value specifies default to be used by the backend while + making the call to the model. + + This field is a member of `oneof`_ ``_temperature``. + top_p (float): + For Nucleus sampling. + + Nucleus sampling considers the smallest set of tokens whose + probability sum is at least ``top_p``. This value specifies + default to be used by the backend while making the call to + the model. + + This field is a member of `oneof`_ ``_top_p``. + top_k (int): + For Top-k sampling. + + Top-k sampling considers the set of ``top_k`` most probable + tokens. This value specifies default to be used by the + backend while making the call to the model. + + This field is a member of `oneof`_ ``_top_k``. 
+ """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + base_model_id: str = proto.Field( + proto.STRING, + number=2, + ) + version: str = proto.Field( + proto.STRING, + number=3, + ) + display_name: str = proto.Field( + proto.STRING, + number=4, + ) + description: str = proto.Field( + proto.STRING, + number=5, + ) + input_token_limit: int = proto.Field( + proto.INT32, + number=6, + ) + output_token_limit: int = proto.Field( + proto.INT32, + number=7, + ) + supported_generation_methods: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=8, + ) + temperature: float = proto.Field( + proto.FLOAT, + number=9, + optional=True, + ) + top_p: float = proto.Field( + proto.FLOAT, + number=10, + optional=True, + ) + top_k: int = proto.Field( + proto.INT32, + number=11, + optional=True, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model_service.py new file mode 100644 index 000000000000..bb10f6ebd82a --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model_service.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.ai.generativelanguage_v1beta2.types import model + + +__protobuf__ = proto.module( + package='google.ai.generativelanguage.v1beta2', + manifest={ + 'GetModelRequest', + 'ListModelsRequest', + 'ListModelsResponse', + }, +) + + +class GetModelRequest(proto.Message): + r"""Request for getting information about a specific Model. + + Attributes: + name (str): + Required. The resource name of the model. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class ListModelsRequest(proto.Message): + r"""Request for listing all Models. + + Attributes: + page_size (int): + The maximum number of ``Models`` to return (per page). + + The service may return fewer models. If unspecified, at most + 50 models will be returned per page. This method returns at + most 1000 models per page, even if you pass a larger + page_size. + page_token (str): + A page token, received from a previous ``ListModels`` call. + + Provide the ``page_token`` returned by one request as an + argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListModels`` must match the call that provided the page + token. 
+ """ + + page_size: int = proto.Field( + proto.INT32, + number=2, + ) + page_token: str = proto.Field( + proto.STRING, + number=3, + ) + + +class ListModelsResponse(proto.Message): + r"""Response from ``ListModel`` containing a paginated list of Models. + + Attributes: + models (MutableSequence[google.ai.generativelanguage_v1beta2.types.Model]): + The returned Models. + next_page_token (str): + A token, which can be sent as ``page_token`` to retrieve the + next page. + + If this field is omitted, there are no more pages. + """ + + @property + def raw_page(self): + return self + + models: MutableSequence[model.Model] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=model.Model, + ) + next_page_token: str = proto.Field( + proto.STRING, + number=2, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/safety.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/safety.py new file mode 100644 index 000000000000..990acf3f4dd2 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/safety.py @@ -0,0 +1,247 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + + +__protobuf__ = proto.module( + package='google.ai.generativelanguage.v1beta2', + manifest={ + 'HarmCategory', + 'ContentFilter', + 'SafetyFeedback', + 'SafetyRating', + 'SafetySetting', + }, +) + + +class HarmCategory(proto.Enum): + r"""The category of a rating. + + These categories cover various kinds of harms that developers + may wish to adjust. + + Values: + HARM_CATEGORY_UNSPECIFIED (0): + Category is unspecified. + HARM_CATEGORY_DEROGATORY (1): + Negative or harmful comments targeting + identity and/or protected attribute. + HARM_CATEGORY_TOXICITY (2): + Content that is rude, disrepspectful, or + profane. + HARM_CATEGORY_VIOLENCE (3): + Describes scenarios depictng violence against + an individual or group, or general descriptions + of gore. + HARM_CATEGORY_SEXUAL (4): + Contains references to sexual acts or other + lewd content. + HARM_CATEGORY_MEDICAL (5): + Promotes unchecked medical advice. + HARM_CATEGORY_DANGEROUS (6): + Dangerous content that promotes, facilitates, + or encourages harmful acts. + """ + HARM_CATEGORY_UNSPECIFIED = 0 + HARM_CATEGORY_DEROGATORY = 1 + HARM_CATEGORY_TOXICITY = 2 + HARM_CATEGORY_VIOLENCE = 3 + HARM_CATEGORY_SEXUAL = 4 + HARM_CATEGORY_MEDICAL = 5 + HARM_CATEGORY_DANGEROUS = 6 + + +class ContentFilter(proto.Message): + r"""Content filtering metadata associated with processing a + single request. + ContentFilter contains a reason and an optional supporting + string. The reason may be unspecified. + + + .. 
+
+    Attributes:
+        reason (google.ai.generativelanguage_v1beta2.types.ContentFilter.BlockedReason):
+            The reason content was blocked during request
+            processing.
+        message (str):
+            A string that describes the filtering
+            behavior in more detail.
+
+            This field is a member of `oneof`_ ``_message``.
+    """
+    class BlockedReason(proto.Enum):
+        r"""A list of reasons why content may have been blocked.
+
+        Values:
+            BLOCKED_REASON_UNSPECIFIED (0):
+                A blocked reason was not specified.
+            SAFETY (1):
+                Content was blocked by safety settings.
+            OTHER (2):
+                Content was blocked, but the reason is
+                uncategorized.
+        """
+        BLOCKED_REASON_UNSPECIFIED = 0
+        SAFETY = 1
+        OTHER = 2
+
+    reason: BlockedReason = proto.Field(
+        proto.ENUM,
+        number=1,
+        enum=BlockedReason,
+    )
+    message: str = proto.Field(
+        proto.STRING,
+        number=2,
+        optional=True,
+    )
+
+
+class SafetyFeedback(proto.Message):
+    r"""Safety feedback for an entire request.
+
+    This field is populated if content in the input and/or response
+    is blocked due to safety settings. SafetyFeedback may not exist
+    for every HarmCategory. Each SafetyFeedback will return the
+    safety settings used by the request as well as the lowest
+    HarmProbability that should be allowed in order to return a
+    result.
+
+    Attributes:
+        rating (google.ai.generativelanguage_v1beta2.types.SafetyRating):
+            Safety rating evaluated from content.
+        setting (google.ai.generativelanguage_v1beta2.types.SafetySetting):
+            Safety settings applied to the request.
+    """
+
+    rating: 'SafetyRating' = proto.Field(
+        proto.MESSAGE,
+        number=1,
+        message='SafetyRating',
+    )
+    setting: 'SafetySetting' = proto.Field(
+        proto.MESSAGE,
+        number=2,
+        message='SafetySetting',
+    )
+
+
+class SafetyRating(proto.Message):
+    r"""Safety rating for a piece of content.
+
+    The safety rating contains the category of harm and the harm
+    probability level in that category for a piece of content.
+    Content is classified for safety across a number of harm
+    categories and the probability of the harm classification is
+    included here.
+
+    Attributes:
+        category (google.ai.generativelanguage_v1beta2.types.HarmCategory):
+            Required. The category for this rating.
+        probability (google.ai.generativelanguage_v1beta2.types.SafetyRating.HarmProbability):
+            Required. The probability of harm for this
+            content.
+    """
+    class HarmProbability(proto.Enum):
+        r"""The probability that a piece of content is harmful.
+
+        The classification system gives the probability of the content
+        being unsafe. This does not indicate the severity of harm for a
+        piece of content.
+
+        Values:
+            HARM_PROBABILITY_UNSPECIFIED (0):
+                Probability is unspecified.
+            NEGLIGIBLE (1):
+                Content has a negligible chance of being
+                unsafe.
+            LOW (2):
+                Content has a low chance of being unsafe.
+            MEDIUM (3):
+                Content has a medium chance of being unsafe.
+            HIGH (4):
+                Content has a high chance of being unsafe.
+        """
+        HARM_PROBABILITY_UNSPECIFIED = 0
+        NEGLIGIBLE = 1
+        LOW = 2
+        MEDIUM = 3
+        HIGH = 4
+
+    category: 'HarmCategory' = proto.Field(
+        proto.ENUM,
+        number=3,
+        enum='HarmCategory',
+    )
+    probability: HarmProbability = proto.Field(
+        proto.ENUM,
+        number=4,
+        enum=HarmProbability,
+    )
+
+
+class SafetySetting(proto.Message):
+    r"""Safety setting, affecting the safety-blocking behavior.
+
+    Passing a safety setting for a category changes the allowed
+    probability that content is blocked.
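Expressed concretely (a hedged sketch; the enum values referenced here are defined just below in this file):

    # Sketch only: BLOCK_MEDIUM_AND_ABOVE allows NEGLIGIBLE and LOW content
    # through and blocks MEDIUM and HIGH, per the threshold docs below.
    from google.ai import generativelanguage_v1beta2 as glm

    setting = glm.SafetySetting(
        category=glm.HarmCategory.HARM_CATEGORY_TOXICITY,
        threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    )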
+
+    Attributes:
+        category (google.ai.generativelanguage_v1beta2.types.HarmCategory):
+            Required. The category for this setting.
+        threshold (google.ai.generativelanguage_v1beta2.types.SafetySetting.HarmBlockThreshold):
+            Required. Controls the probability threshold
+            at which harm is blocked.
+    """
+    class HarmBlockThreshold(proto.Enum):
+        r"""Block at and beyond a specified harm probability.
+
+        Values:
+            HARM_BLOCK_THRESHOLD_UNSPECIFIED (0):
+                Threshold is unspecified.
+            BLOCK_LOW_AND_ABOVE (1):
+                Content with NEGLIGIBLE will be allowed.
+            BLOCK_MEDIUM_AND_ABOVE (2):
+                Content with NEGLIGIBLE and LOW will be
+                allowed.
+            BLOCK_ONLY_HIGH (3):
+                Content with NEGLIGIBLE, LOW, and MEDIUM will
+                be allowed.
+        """
+        HARM_BLOCK_THRESHOLD_UNSPECIFIED = 0
+        BLOCK_LOW_AND_ABOVE = 1
+        BLOCK_MEDIUM_AND_ABOVE = 2
+        BLOCK_ONLY_HIGH = 3
+
+    category: 'HarmCategory' = proto.Field(
+        proto.ENUM,
+        number=3,
+        enum='HarmCategory',
+    )
+    threshold: HarmBlockThreshold = proto.Field(
+        proto.ENUM,
+        number=4,
+        enum=HarmBlockThreshold,
+    )
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/text_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/text_service.py
new file mode 100644
index 000000000000..572f3c5392b2
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/text_service.py
@@ -0,0 +1,333 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+from typing import MutableMapping, MutableSequence
+
+import proto  # type: ignore
+
+from google.ai.generativelanguage_v1beta2.types import citation
+from google.ai.generativelanguage_v1beta2.types import safety
+
+
+__protobuf__ = proto.module(
+    package='google.ai.generativelanguage.v1beta2',
+    manifest={
+        'GenerateTextRequest',
+        'GenerateTextResponse',
+        'TextPrompt',
+        'TextCompletion',
+        'EmbedTextRequest',
+        'EmbedTextResponse',
+        'Embedding',
+    },
+)
+
+
+class GenerateTextRequest(proto.Message):
+    r"""Request to generate a text completion response from the
+    model.
+
+
+    .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+    Attributes:
+        model (str):
+            Required. The model name to use with the
+            format name=models/{model}.
+        prompt (google.ai.generativelanguage_v1beta2.types.TextPrompt):
+            Required. The free-form input text given to
+            the model as a prompt.
+            Given a prompt, the model will generate a
+            TextCompletion response it predicts as the
+            completion of the input text.
+        temperature (float):
+            Controls the randomness of the output. Note: The default
+            value varies by model, see the ``Model.temperature``
+            attribute of the ``Model`` returned from the ``getModel``
+            function.
+
+            Values can range from [0.0,1.0], inclusive. A value closer
+            to 1.0 will produce responses that are more varied and
+            creative, while a value closer to 0.0 will typically result
+            in more straightforward responses from the model.
+
+            This field is a member of `oneof`_ ``_temperature``.
+        candidate_count (int):
+            Number of generated responses to return.
+
+            This value must be between [1, 8], inclusive. If unset, this
+            will default to 1.
+
+            This field is a member of `oneof`_ ``_candidate_count``.
+        max_output_tokens (int):
+            The maximum number of tokens to include in a
+            candidate.
+            If unset, this will default to 64.
+
+            This field is a member of `oneof`_ ``_max_output_tokens``.
+        top_p (float):
+            The maximum cumulative probability of tokens to consider
+            when sampling.
+
+            The model uses combined Top-k and nucleus sampling.
+
+            Tokens are sorted based on their assigned probabilities so
+            that only the most likely tokens are considered. Top-k
+            sampling directly limits the maximum number of tokens to
+            consider, while Nucleus sampling limits the number of tokens
+            based on the cumulative probability.
+
+            Note: The default value varies by model, see the
+            ``Model.top_p`` attribute of the ``Model`` returned from the
+            ``getModel`` function.
+
+            This field is a member of `oneof`_ ``_top_p``.
+        top_k (int):
+            The maximum number of tokens to consider when sampling.
+
+            The model uses combined Top-k and nucleus sampling.
+
+            Top-k sampling considers the set of ``top_k`` most probable
+            tokens. Defaults to 40.
+
+            Note: The default value varies by model, see the
+            ``Model.top_k`` attribute of the ``Model`` returned from the
+            ``getModel`` function.
+
+            This field is a member of `oneof`_ ``_top_k``.
+        safety_settings (MutableSequence[google.ai.generativelanguage_v1beta2.types.SafetySetting]):
+            A list of unique ``SafetySetting`` instances for blocking
+            unsafe content that will be enforced on the
+            ``GenerateTextRequest.prompt`` and
+            ``GenerateTextResponse.candidates``. There should not be
+            more than one setting for each ``SafetyCategory`` type. The
+            API will block any prompts and responses that fail to meet
+            the thresholds set by these settings. This list overrides
+            the default settings for each ``SafetyCategory`` specified
+            in the safety_settings. If there is no ``SafetySetting`` for
+            a given ``SafetyCategory`` provided in the list, the API
+            will use the default safety setting for that category.
+        stop_sequences (MutableSequence[str]):
+            The set of character sequences (up to 5) that
+            will stop output generation. If specified, the
+            API will stop at the first appearance of a stop
+            sequence. The stop sequence will not be included
+            as part of the response.
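Taken together, a hedged sketch of a fully populated request; the model name and parameter values are illustrative:

    # Sketch only: sampling, safety, and stop-sequence fields in one request.
    from google.ai import generativelanguage_v1beta2 as glm

    request = glm.GenerateTextRequest(
        model="models/text-bison-001",  # illustrative model name
        prompt=glm.TextPrompt(text="Write a haiku about the sea."),
        temperature=0.7,
        candidate_count=2,
        max_output_tokens=128,
        stop_sequences=["\n\n"],
        safety_settings=[
            glm.SafetySetting(
                category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS,
                threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH,
            )
        ],
    )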
+ """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + prompt: 'TextPrompt' = proto.Field( + proto.MESSAGE, + number=2, + message='TextPrompt', + ) + temperature: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + candidate_count: int = proto.Field( + proto.INT32, + number=4, + optional=True, + ) + max_output_tokens: int = proto.Field( + proto.INT32, + number=5, + optional=True, + ) + top_p: float = proto.Field( + proto.FLOAT, + number=6, + optional=True, + ) + top_k: int = proto.Field( + proto.INT32, + number=7, + optional=True, + ) + safety_settings: MutableSequence[safety.SafetySetting] = proto.RepeatedField( + proto.MESSAGE, + number=8, + message=safety.SafetySetting, + ) + stop_sequences: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=9, + ) + + +class GenerateTextResponse(proto.Message): + r"""The response from the model, including candidate completions. + + Attributes: + candidates (MutableSequence[google.ai.generativelanguage_v1beta2.types.TextCompletion]): + Candidate responses from the model. + filters (MutableSequence[google.ai.generativelanguage_v1beta2.types.ContentFilter]): + A set of content filtering metadata for the prompt and + response text. + + This indicates which ``SafetyCategory``\ (s) blocked a + candidate from this response, the lowest ``HarmProbability`` + that triggered a block, and the HarmThreshold setting for + that category. This indicates the smallest change to the + ``SafetySettings`` that would be necessary to unblock at + least 1 response. + + The blocking is configured by the ``SafetySettings`` in the + request (or the default ``SafetySettings`` of the API). + safety_feedback (MutableSequence[google.ai.generativelanguage_v1beta2.types.SafetyFeedback]): + Returns any safety feedback related to + content filtering. + """ + + candidates: MutableSequence['TextCompletion'] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message='TextCompletion', + ) + filters: MutableSequence[safety.ContentFilter] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message=safety.ContentFilter, + ) + safety_feedback: MutableSequence[safety.SafetyFeedback] = proto.RepeatedField( + proto.MESSAGE, + number=4, + message=safety.SafetyFeedback, + ) + + +class TextPrompt(proto.Message): + r"""Text given to the model as a prompt. + + The Model will use this TextPrompt to Generate a text + completion. + + Attributes: + text (str): + Required. The prompt text. + """ + + text: str = proto.Field( + proto.STRING, + number=1, + ) + + +class TextCompletion(proto.Message): + r"""Output text returned from a model. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + output (str): + Output only. The generated text returned from + the model. + safety_ratings (MutableSequence[google.ai.generativelanguage_v1beta2.types.SafetyRating]): + Ratings for the safety of a response. + + There is at most one rating per category. + citation_metadata (google.ai.generativelanguage_v1beta2.types.CitationMetadata): + Output only. Citation information for model-generated + ``output`` in this ``TextCompletion``. + + This field may be populated with attribution information for + any text included in the ``output``. + + This field is a member of `oneof`_ ``_citation_metadata``. 
+ """ + + output: str = proto.Field( + proto.STRING, + number=1, + ) + safety_ratings: MutableSequence[safety.SafetyRating] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=safety.SafetyRating, + ) + citation_metadata: citation.CitationMetadata = proto.Field( + proto.MESSAGE, + number=3, + optional=True, + message=citation.CitationMetadata, + ) + + +class EmbedTextRequest(proto.Message): + r"""Request to get a text embedding from the model. + + Attributes: + model (str): + Required. The model name to use with the + format model=models/{model}. + text (str): + Required. The free-form input text that the + model will turn into an embedding. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + text: str = proto.Field( + proto.STRING, + number=2, + ) + + +class EmbedTextResponse(proto.Message): + r"""The response to a EmbedTextRequest. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + embedding (google.ai.generativelanguage_v1beta2.types.Embedding): + Output only. The embedding generated from the + input text. + + This field is a member of `oneof`_ ``_embedding``. + """ + + embedding: 'Embedding' = proto.Field( + proto.MESSAGE, + number=1, + optional=True, + message='Embedding', + ) + + +class Embedding(proto.Message): + r"""A list of floats representing the embedding. + + Attributes: + value (MutableSequence[float]): + The embedding values. + """ + + value: MutableSequence[float] = proto.RepeatedField( + proto.FLOAT, + number=1, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/mypy.ini b/owl-bot-staging/google-ai-generativelanguage/v1beta2/mypy.ini new file mode 100644 index 000000000000..574c5aed394b --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/mypy.ini @@ -0,0 +1,3 @@ +[mypy] +python_version = 3.7 +namespace_packages = True diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/noxfile.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/noxfile.py new file mode 100644 index 000000000000..96375ae41831 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/noxfile.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+import os
+import pathlib
+import shutil
+import subprocess
+import sys
+
+
+import nox  # type: ignore
+
+ALL_PYTHON = [
+    "3.7",
+    "3.8",
+    "3.9",
+    "3.10",
+    "3.11",
+]
+
+CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()
+
+LOWER_BOUND_CONSTRAINTS_FILE = CURRENT_DIRECTORY / "constraints.txt"
+PACKAGE_NAME = subprocess.check_output([sys.executable, "setup.py", "--name"], encoding="utf-8")
+
+BLACK_VERSION = "black==22.3.0"
+BLACK_PATHS = ["docs", "google", "tests", "samples", "noxfile.py", "setup.py"]
+DEFAULT_PYTHON_VERSION = "3.11"
+
+nox.sessions = [
+    "unit",
+    "cover",
+    "mypy",
+    "check_lower_bounds",
+    # exclude update_lower_bounds from default
+    "docs",
+    "blacken",
+    "lint",
+    "lint_setup_py",
+]
+
+@nox.session(python=ALL_PYTHON)
+def unit(session):
+    """Run the unit test suite."""
+
+    session.install('coverage', 'pytest', 'pytest-cov', 'pytest-asyncio', 'asyncmock; python_version < "3.8"')
+    session.install('-e', '.')
+
+    session.run(
+        'py.test',
+        '--quiet',
+        '--cov=google/ai/generativelanguage_v1beta2/',
+        '--cov=tests/',
+        '--cov-config=.coveragerc',
+        '--cov-report=term',
+        '--cov-report=html',
+        os.path.join('tests', 'unit', ''.join(session.posargs))
+    )
+
+
+@nox.session(python=DEFAULT_PYTHON_VERSION)
+def cover(session):
+    """Run the final coverage report.
+    This outputs the coverage report aggregating coverage from the unit
+    test runs (not system test runs), and then erases coverage data.
+    """
+    session.install("coverage", "pytest-cov")
+    session.run("coverage", "report", "--show-missing", "--fail-under=100")
+
+    session.run("coverage", "erase")
+
+
+@nox.session(python=ALL_PYTHON)
+def mypy(session):
+    """Run the type checker."""
+    session.install(
+        'mypy',
+        'types-requests',
+        'types-protobuf'
+    )
+    session.install('.')
+    session.run(
+        'mypy',
+        '--explicit-package-bases',
+        'google',
+    )
+
+
+@nox.session
+def update_lower_bounds(session):
+    """Update lower bounds in constraints.txt to match setup.py"""
+    session.install('google-cloud-testutils')
+    session.install('.')
+
+    session.run(
+        'lower-bound-checker',
+        'update',
+        '--package-name',
+        PACKAGE_NAME,
+        '--constraints-file',
+        str(LOWER_BOUND_CONSTRAINTS_FILE),
+    )
+
+
+@nox.session
+def check_lower_bounds(session):
+    """Check lower bounds in setup.py are reflected in constraints file"""
+    session.install('google-cloud-testutils')
+    session.install('.')
+
+    session.run(
+        'lower-bound-checker',
+        'check',
+        '--package-name',
+        PACKAGE_NAME,
+        '--constraints-file',
+        str(LOWER_BOUND_CONSTRAINTS_FILE),
+    )
+
+@nox.session(python=DEFAULT_PYTHON_VERSION)
+def docs(session):
+    """Build the docs for this library."""
+
+    session.install("-e", ".")
+    session.install("sphinx==7.0.1", "alabaster", "recommonmark")
+
+    shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True)
+    session.run(
+        "sphinx-build",
+        "-W",  # warnings as errors
+        "-T",  # show full traceback on exception
+        "-N",  # no colors
+        "-b",
+        "html",
+        "-d",
+        os.path.join("docs", "_build", "doctrees", ""),
+        os.path.join("docs", ""),
+        os.path.join("docs", "_build", "html", ""),
+    )
+
+
+@nox.session(python=DEFAULT_PYTHON_VERSION)
+def lint(session):
+    """Run linters.
+
+    Returns a failure if the linters find linting errors or sufficiently
+    serious code quality issues.
+    """
+    session.install("flake8", BLACK_VERSION)
+    session.run(
+        "black",
+        "--check",
+        *BLACK_PATHS,
+    )
+    session.run("flake8", "google", "tests", "samples")
+
+
+@nox.session(python=DEFAULT_PYTHON_VERSION)
+def blacken(session):
+    """Run black.
Format code to uniform standard.""" + session.install(BLACK_VERSION) + session.run( + "black", + *BLACK_PATHS, + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def lint_setup_py(session): + """Verify that setup.py is valid (including RST check).""" + session.install("docutils", "pygments") + session.run("python", "setup.py", "check", "--restructuredtext", "--strict") diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_async.py new file mode 100644 index 000000000000..1b587e44368d --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CountMessageTokens +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta2_generated_DiscussService_CountMessageTokens_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta2 + + +async def sample_count_message_tokens(): + # Create a client + client = generativelanguage_v1beta2.DiscussServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta2.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1beta2.CountMessageTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.count_message_tokens(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta2_generated_DiscussService_CountMessageTokens_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_sync.py new file mode 100644 index 000000000000..590d967fdfa6 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CountMessageTokens +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta2_generated_DiscussService_CountMessageTokens_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta2 + + +def sample_count_message_tokens(): + # Create a client + client = generativelanguage_v1beta2.DiscussServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta2.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1beta2.CountMessageTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.count_message_tokens(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta2_generated_DiscussService_CountMessageTokens_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_async.py new file mode 100644 index 000000000000..22848d706b77 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GenerateMessage +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta2_generated_DiscussService_GenerateMessage_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
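One caution on the samples in this series: `prompt.messages` is a repeated field, so the flattened assignment `prompt.messages.content = "content_value"` may not run as-is under proto-plus. An explicit append (an assumption about the sample's intent, not generator output) would be:

    # Sketch only: explicit alternative to the flattened repeated-field write.
    from google.ai import generativelanguage_v1beta2 as glm

    prompt = glm.MessagePrompt()
    prompt.messages.append(glm.Message(content="content_value"))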
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta2 + + +async def sample_generate_message(): + # Create a client + client = generativelanguage_v1beta2.DiscussServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta2.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1beta2.GenerateMessageRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.generate_message(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta2_generated_DiscussService_GenerateMessage_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_sync.py new file mode 100644 index 000000000000..30106bdee93b --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GenerateMessage +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta2_generated_DiscussService_GenerateMessage_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta2 + + +def sample_generate_message(): + # Create a client + client = generativelanguage_v1beta2.DiscussServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta2.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1beta2.GenerateMessageRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.generate_message(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta2_generated_DiscussService_GenerateMessage_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_async.py new file mode 100644 index 000000000000..1eb30ff00aaa --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta2_generated_ModelService_GetModel_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_async.py
new file mode 100644
index 000000000000..1eb30ff00aaa
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_async.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for GetModel
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta2_generated_ModelService_GetModel_async]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+#   client as shown in:
+#   https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta2
+
+
+async def sample_get_model():
+    # Create a client
+    client = generativelanguage_v1beta2.ModelServiceAsyncClient()
+
+    # Initialize request argument(s)
+    request = generativelanguage_v1beta2.GetModelRequest(
+        name="name_value",
+    )
+
+    # Make the request
+    response = await client.get_model(request=request)
+
+    # Handle the response
+    print(response)
+
+# [END generativelanguage_v1beta2_generated_ModelService_GetModel_async]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_sync.py
new file mode 100644
index 000000000000..84eda9615b78
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_sync.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for GetModel
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta2_generated_ModelService_GetModel_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+#   client as shown in:
+#   https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta2
+
+
+def sample_get_model():
+    # Create a client
+    client = generativelanguage_v1beta2.ModelServiceClient()
+
+    # Initialize request argument(s)
+    request = generativelanguage_v1beta2.GetModelRequest(
+        name="name_value",
+    )
+
+    # Make the request
+    response = client.get_model(request=request)
+
+    # Handle the response
+    print(response)
+
+# [END generativelanguage_v1beta2_generated_ModelService_GetModel_sync]
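
get_model likewise accepts the flattened name argument listed in the metadata, and the returned Model message carries the model's descriptive fields. A short sketch, assuming the v1beta2 Model schema (display_name) and a placeholder model resource name:

    from google.ai import generativelanguage_v1beta2

    def sample_get_model_flattened():
        client = generativelanguage_v1beta2.ModelServiceClient()
        # The flattened "name" argument replaces the explicit GetModelRequest.
        model = client.get_model(name="models/chat-bison-001")  # placeholder
        print(model.name, model.display_name)
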
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_async.py
new file mode 100644
index 000000000000..7d21ae65d7e6
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_async.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for ListModels
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta2_generated_ModelService_ListModels_async]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+#   client as shown in:
+#   https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta2
+
+
+async def sample_list_models():
+    # Create a client
+    client = generativelanguage_v1beta2.ModelServiceAsyncClient()
+
+    # Initialize request argument(s)
+    request = generativelanguage_v1beta2.ListModelsRequest(
+    )
+
+    # Make the request
+    page_result = await client.list_models(request=request)
+
+    # Handle the response
+    async for response in page_result:
+        print(response)
+
+# [END generativelanguage_v1beta2_generated_ModelService_ListModels_async]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_sync.py
new file mode 100644
index 000000000000..e94decf56a96
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_sync.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for ListModels
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta2_generated_ModelService_ListModels_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+#   client as shown in:
+#   https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta2
+
+
+def sample_list_models():
+    # Create a client
+    client = generativelanguage_v1beta2.ModelServiceClient()
+
+    # Initialize request argument(s)
+    request = generativelanguage_v1beta2.ListModelsRequest(
+    )
+
+    # Make the request
+    page_result = client.list_models(request=request)
+
+    # Handle the response
+    for response in page_result:
+        print(response)
+
+# [END generativelanguage_v1beta2_generated_ModelService_ListModels_sync]
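
The pagers returned by list_models hide the page-token plumbing behind plain item iteration, and the metadata lists page_size and page_token as flattened arguments. A sketch of page-wise iteration with the sync pager, using the standard GAPIC pager surface:

    from google.ai import generativelanguage_v1beta2

    def sample_list_models_by_page():
        client = generativelanguage_v1beta2.ModelServiceClient()
        pager = client.list_models(page_size=10)
        # pager.pages yields whole ListModelsResponse messages; the pager
        # fetches each following page lazily via the returned page tokens.
        for page in pager.pages:
            for model in page.models:
                print(model.name)
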
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_async.py
new file mode 100644
index 000000000000..d970ee8f589c
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_async.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for EmbedText
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta2_generated_TextService_EmbedText_async]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+#   client as shown in:
+#   https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta2
+
+
+async def sample_embed_text():
+    # Create a client
+    client = generativelanguage_v1beta2.TextServiceAsyncClient()
+
+    # Initialize request argument(s)
+    request = generativelanguage_v1beta2.EmbedTextRequest(
+        model="model_value",
+        text="text_value",
+    )
+
+    # Make the request
+    response = await client.embed_text(request=request)
+
+    # Handle the response
+    print(response)
+
+# [END generativelanguage_v1beta2_generated_TextService_EmbedText_async]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_sync.py
new file mode 100644
index 000000000000..c00795a1f795
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_sync.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for EmbedText
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta2_generated_TextService_EmbedText_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+#   client as shown in:
+#   https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta2
+
+
+def sample_embed_text():
+    # Create a client
+    client = generativelanguage_v1beta2.TextServiceClient()
+
+    # Initialize request argument(s)
+    request = generativelanguage_v1beta2.EmbedTextRequest(
+        model="model_value",
+        text="text_value",
+    )
+
+    # Make the request
+    response = client.embed_text(request=request)
+
+    # Handle the response
+    print(response)
+
+# [END generativelanguage_v1beta2_generated_TextService_EmbedText_sync]
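
EmbedText wraps the vector in the response's embedding field. A sketch of reading it, assuming the v1beta2 Embedding message with a repeated float value field and a placeholder model name:

    from google.ai import generativelanguage_v1beta2

    def sample_embed_text_vector():
        client = generativelanguage_v1beta2.TextServiceClient()
        response = client.embed_text(
            model="models/embedding-gecko-001",  # placeholder model name
            text="Hello world",
        )
        # Embedding.value holds the vector as a list of floats.
        print(len(response.embedding.value))
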
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_async.py
new file mode 100644
index 000000000000..f41f480f205c
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_async.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for GenerateText
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta2_generated_TextService_GenerateText_async]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+#   client as shown in:
+#   https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta2
+
+
+async def sample_generate_text():
+    # Create a client
+    client = generativelanguage_v1beta2.TextServiceAsyncClient()
+
+    # Initialize request argument(s)
+    prompt = generativelanguage_v1beta2.TextPrompt()
+    prompt.text = "text_value"
+
+    request = generativelanguage_v1beta2.GenerateTextRequest(
+        model="model_value",
+        prompt=prompt,
+    )
+
+    # Make the request
+    response = await client.generate_text(request=request)
+
+    # Handle the response
+    print(response)
+
+# [END generativelanguage_v1beta2_generated_TextService_GenerateText_async]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_sync.py
new file mode 100644
index 000000000000..900ed0003aeb
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_sync.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for GenerateText
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta2_generated_TextService_GenerateText_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta2 + + +def sample_generate_text(): + # Create a client + client = generativelanguage_v1beta2.TextServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta2.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1beta2.GenerateTextRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.generate_text(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta2_generated_TextService_GenerateText_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta2.json b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta2.json new file mode 100644 index 000000000000..5b7d0a0509b4 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta2.json @@ -0,0 +1,1093 @@ +{ + "clientLibrary": { + "apis": [ + { + "id": "google.ai.generativelanguage.v1beta2", + "version": "v1beta2" + } + ], + "language": "PYTHON", + "name": "google-ai-generativelanguage", + "version": "0.1.0" + }, + "snippets": [ + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceAsyncClient", + "shortName": "DiscussServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceAsyncClient.count_message_tokens", + "method": { + "fullName": "google.ai.generativelanguage.v1beta2.DiscussService.CountMessageTokens", + "service": { + "fullName": "google.ai.generativelanguage.v1beta2.DiscussService", + "shortName": "DiscussService" + }, + "shortName": "CountMessageTokens" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta2.types.CountMessageTokensRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta2.types.MessagePrompt" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta2.types.CountMessageTokensResponse", + "shortName": "count_message_tokens" + }, + "description": "Sample for CountMessageTokens", + "file": "generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta2_generated_DiscussService_CountMessageTokens_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": 
"google.ai.generativelanguage_v1beta2.DiscussServiceClient", + "shortName": "DiscussServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceClient.count_message_tokens", + "method": { + "fullName": "google.ai.generativelanguage.v1beta2.DiscussService.CountMessageTokens", + "service": { + "fullName": "google.ai.generativelanguage.v1beta2.DiscussService", + "shortName": "DiscussService" + }, + "shortName": "CountMessageTokens" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta2.types.CountMessageTokensRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta2.types.MessagePrompt" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta2.types.CountMessageTokensResponse", + "shortName": "count_message_tokens" + }, + "description": "Sample for CountMessageTokens", + "file": "generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta2_generated_DiscussService_CountMessageTokens_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceAsyncClient", + "shortName": "DiscussServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceAsyncClient.generate_message", + "method": { + "fullName": "google.ai.generativelanguage.v1beta2.DiscussService.GenerateMessage", + "service": { + "fullName": "google.ai.generativelanguage.v1beta2.DiscussService", + "shortName": "DiscussService" + }, + "shortName": "GenerateMessage" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta2.types.GenerateMessageRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta2.types.MessagePrompt" + }, + { + "name": "temperature", + "type": "float" + }, + { + "name": "candidate_count", + "type": "int" + }, + { + "name": "top_p", + "type": "float" + }, + { + "name": "top_k", + "type": "int" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta2.types.GenerateMessageResponse", + "shortName": "generate_message" + }, + "description": "Sample for GenerateMessage", + "file": "generativelanguage_v1beta2_generated_discuss_service_generate_message_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta2_generated_DiscussService_GenerateMessage_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + 
"end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta2_generated_discuss_service_generate_message_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceClient", + "shortName": "DiscussServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceClient.generate_message", + "method": { + "fullName": "google.ai.generativelanguage.v1beta2.DiscussService.GenerateMessage", + "service": { + "fullName": "google.ai.generativelanguage.v1beta2.DiscussService", + "shortName": "DiscussService" + }, + "shortName": "GenerateMessage" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta2.types.GenerateMessageRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta2.types.MessagePrompt" + }, + { + "name": "temperature", + "type": "float" + }, + { + "name": "candidate_count", + "type": "int" + }, + { + "name": "top_p", + "type": "float" + }, + { + "name": "top_k", + "type": "int" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta2.types.GenerateMessageResponse", + "shortName": "generate_message" + }, + "description": "Sample for GenerateMessage", + "file": "generativelanguage_v1beta2_generated_discuss_service_generate_message_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta2_generated_DiscussService_GenerateMessage_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta2_generated_discuss_service_generate_message_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceAsyncClient.get_model", + "method": { + "fullName": "google.ai.generativelanguage.v1beta2.ModelService.GetModel", + "service": { + "fullName": "google.ai.generativelanguage.v1beta2.ModelService", + "shortName": "ModelService" + }, + "shortName": "GetModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta2.types.GetModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta2.types.Model", + "shortName": "get_model" + }, + "description": "Sample for GetModel", + "file": 
"generativelanguage_v1beta2_generated_model_service_get_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta2_generated_ModelService_GetModel_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta2_generated_model_service_get_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceClient.get_model", + "method": { + "fullName": "google.ai.generativelanguage.v1beta2.ModelService.GetModel", + "service": { + "fullName": "google.ai.generativelanguage.v1beta2.ModelService", + "shortName": "ModelService" + }, + "shortName": "GetModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta2.types.GetModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta2.types.Model", + "shortName": "get_model" + }, + "description": "Sample for GetModel", + "file": "generativelanguage_v1beta2_generated_model_service_get_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta2_generated_ModelService_GetModel_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta2_generated_model_service_get_model_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceAsyncClient.list_models", + "method": { + "fullName": "google.ai.generativelanguage.v1beta2.ModelService.ListModels", + "service": { + "fullName": "google.ai.generativelanguage.v1beta2.ModelService", + "shortName": "ModelService" + }, + "shortName": "ListModels" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta2.types.ListModelsRequest" + }, + { + "name": "page_size", + "type": "int" + }, + { + "name": "page_token", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta2.services.model_service.pagers.ListModelsAsyncPager", + "shortName": "list_models" + }, + "description": "Sample for ListModels", + "file": 
"generativelanguage_v1beta2_generated_model_service_list_models_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta2_generated_ModelService_ListModels_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta2_generated_model_service_list_models_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceClient.list_models", + "method": { + "fullName": "google.ai.generativelanguage.v1beta2.ModelService.ListModels", + "service": { + "fullName": "google.ai.generativelanguage.v1beta2.ModelService", + "shortName": "ModelService" + }, + "shortName": "ListModels" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta2.types.ListModelsRequest" + }, + { + "name": "page_size", + "type": "int" + }, + { + "name": "page_token", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta2.services.model_service.pagers.ListModelsPager", + "shortName": "list_models" + }, + "description": "Sample for ListModels", + "file": "generativelanguage_v1beta2_generated_model_service_list_models_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta2_generated_ModelService_ListModels_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta2_generated_model_service_list_models_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta2.TextServiceAsyncClient", + "shortName": "TextServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta2.TextServiceAsyncClient.embed_text", + "method": { + "fullName": "google.ai.generativelanguage.v1beta2.TextService.EmbedText", + "service": { + "fullName": "google.ai.generativelanguage.v1beta2.TextService", + "shortName": "TextService" + }, + "shortName": "EmbedText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta2.types.EmbedTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "text", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta2.types.EmbedTextResponse", + "shortName": "embed_text" + }, + 
"description": "Sample for EmbedText", + "file": "generativelanguage_v1beta2_generated_text_service_embed_text_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta2_generated_TextService_EmbedText_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta2_generated_text_service_embed_text_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta2.TextServiceClient", + "shortName": "TextServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta2.TextServiceClient.embed_text", + "method": { + "fullName": "google.ai.generativelanguage.v1beta2.TextService.EmbedText", + "service": { + "fullName": "google.ai.generativelanguage.v1beta2.TextService", + "shortName": "TextService" + }, + "shortName": "EmbedText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta2.types.EmbedTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "text", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta2.types.EmbedTextResponse", + "shortName": "embed_text" + }, + "description": "Sample for EmbedText", + "file": "generativelanguage_v1beta2_generated_text_service_embed_text_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta2_generated_TextService_EmbedText_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta2_generated_text_service_embed_text_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta2.TextServiceAsyncClient", + "shortName": "TextServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta2.TextServiceAsyncClient.generate_text", + "method": { + "fullName": "google.ai.generativelanguage.v1beta2.TextService.GenerateText", + "service": { + "fullName": "google.ai.generativelanguage.v1beta2.TextService", + "shortName": "TextService" + }, + "shortName": "GenerateText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta2.types.GenerateTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta2.types.TextPrompt" + }, + { + "name": "temperature", + "type": "float" + }, + { + "name": "candidate_count", + "type": "int" + }, + { + "name": "max_output_tokens", + "type": "int" + }, + { + "name": "top_p", + "type": "float" + }, + { + "name": "top_k", + "type": "int" + }, + { 
+ "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta2.types.GenerateTextResponse", + "shortName": "generate_text" + }, + "description": "Sample for GenerateText", + "file": "generativelanguage_v1beta2_generated_text_service_generate_text_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta2_generated_TextService_GenerateText_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta2_generated_text_service_generate_text_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta2.TextServiceClient", + "shortName": "TextServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta2.TextServiceClient.generate_text", + "method": { + "fullName": "google.ai.generativelanguage.v1beta2.TextService.GenerateText", + "service": { + "fullName": "google.ai.generativelanguage.v1beta2.TextService", + "shortName": "TextService" + }, + "shortName": "GenerateText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta2.types.GenerateTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta2.types.TextPrompt" + }, + { + "name": "temperature", + "type": "float" + }, + { + "name": "candidate_count", + "type": "int" + }, + { + "name": "max_output_tokens", + "type": "int" + }, + { + "name": "top_p", + "type": "float" + }, + { + "name": "top_k", + "type": "int" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta2.types.GenerateTextResponse", + "shortName": "generate_text" + }, + "description": "Sample for GenerateText", + "file": "generativelanguage_v1beta2_generated_text_service_generate_text_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta2_generated_TextService_GenerateText_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta2_generated_text_service_generate_text_sync.py" + } + ] +} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/scripts/fixup_generativelanguage_v1beta2_keywords.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/scripts/fixup_generativelanguage_v1beta2_keywords.py new file mode 100644 index 000000000000..0c638051d5bf --- /dev/null +++ 
b/owl-bot-staging/google-ai-generativelanguage/v1beta2/scripts/fixup_generativelanguage_v1beta2_keywords.py
@@ -0,0 +1,181 @@
+#! /usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import argparse
+import os
+import libcst as cst
+import pathlib
+import sys
+from typing import (Any, Callable, Dict, List, Sequence, Tuple)
+
+
+def partition(
+    predicate: Callable[[Any], bool],
+    iterator: Sequence[Any]
+) -> Tuple[List[Any], List[Any]]:
+    """A stable, out-of-place partition."""
+    results = ([], [])
+
+    for i in iterator:
+        results[int(predicate(i))].append(i)
+
+    # Returns trueList, falseList
+    return results[1], results[0]
+
+
+class generativelanguageCallTransformer(cst.CSTTransformer):
+    CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata')
+    METHOD_TO_PARAMS: Dict[str, Tuple[str]] = {
+        'count_message_tokens': ('model', 'prompt', ),
+        'embed_text': ('model', 'text', ),
+        'generate_message': ('model', 'prompt', 'temperature', 'candidate_count', 'top_p', 'top_k', ),
+        'generate_text': ('model', 'prompt', 'temperature', 'candidate_count', 'max_output_tokens', 'top_p', 'top_k', 'safety_settings', 'stop_sequences', ),
+        'get_model': ('name', ),
+        'list_models': ('page_size', 'page_token', ),
+    }
+
+    def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode:
+        try:
+            key = original.func.attr.value
+            kword_params = self.METHOD_TO_PARAMS[key]
+        except (AttributeError, KeyError):
+            # Either not a method from the API or too convoluted to be sure.
+            return updated
+
+        # If the existing code is valid, keyword args come after positional args.
+        # Therefore, all positional args must map to the first parameters.
+        args, kwargs = partition(lambda a: not bool(a.keyword), updated.args)
+        if any(k.keyword.value == "request" for k in kwargs):
+            # We've already fixed this file, don't fix it again.
+            return updated
+
+        kwargs, ctrl_kwargs = partition(
+            lambda a: a.keyword.value not in self.CTRL_PARAMS,
+            kwargs
+        )
+
+        args, ctrl_args = args[:len(kword_params)], args[len(kword_params):]
+        ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl))
+                           for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS))
+
+        request_arg = cst.Arg(
+            value=cst.Dict([
+                    cst.DictElement(
+                        cst.SimpleString("'{}'".format(name)),
+                        cst.Element(value=arg.value)
+                    )
+                    # Note: the args + kwargs looks silly, but keep in mind that
+                    # the control parameters had to be stripped out, and that
+                    # those could have been passed positionally or by keyword.
+                    for name, arg in zip(kword_params, args + kwargs)]),
+            keyword=cst.Name("request")
+        )
+
+        return updated.with_changes(
+            args=[request_arg] + ctrl_kwargs
+        )
+
+
+def fix_files(
+    in_dir: pathlib.Path,
+    out_dir: pathlib.Path,
+    *,
+    transformer=generativelanguageCallTransformer(),
+):
+    """Duplicate the input dir to the output dir, fixing file method calls.
+
+    Preconditions:
+    * in_dir is a real directory
+    * out_dir is a real, empty directory
+    """
+    pyfile_gen = (
+        pathlib.Path(os.path.join(root, f))
+        for root, _, files in os.walk(in_dir)
+        for f in files if os.path.splitext(f)[1] == ".py"
+    )
+
+    for fpath in pyfile_gen:
+        with open(fpath, 'r') as f:
+            src = f.read()
+
+        # Parse the code and insert method call fixes.
+        tree = cst.parse_module(src)
+        updated = tree.visit(transformer)
+
+        # Create the path and directory structure for the new file.
+        updated_path = out_dir.joinpath(fpath.relative_to(in_dir))
+        updated_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Generate the updated source file at the corresponding path.
+        with open(updated_path, 'w') as f:
+            f.write(updated.code)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="""Fix up source that uses the generativelanguage client library.
+
+The existing sources are NOT overwritten but are copied to output_dir with changes made.
+
+Note: This tool operates at a best-effort level at converting positional
+      parameters in client method calls to keyword based parameters.
+      Cases where it WILL FAIL include
+      A) * or ** expansion in a method call.
+      B) Calls via function or method alias (includes free function calls)
+      C) Indirect or dispatched calls (e.g. the method is looked up dynamically)
+
+      These all constitute false negatives. The tool will also detect false
+      positives when an API method shares a name with another method.
+""")
+    parser.add_argument(
+        '-d',
+        '--input-directory',
+        required=True,
+        dest='input_dir',
+        help='the input directory to walk for python files to fix up',
+    )
+    parser.add_argument(
+        '-o',
+        '--output-directory',
+        required=True,
+        dest='output_dir',
+        help='the directory to output files fixed via un-flattening',
+    )
+    args = parser.parse_args()
+    input_dir = pathlib.Path(args.input_dir)
+    output_dir = pathlib.Path(args.output_dir)
+    if not input_dir.is_dir():
+        print(
+            f"input directory '{input_dir}' does not exist or is not a directory",
+            file=sys.stderr,
+        )
+        sys.exit(-1)
+
+    if not output_dir.is_dir():
+        print(
+            f"output directory '{output_dir}' does not exist or is not a directory",
+            file=sys.stderr,
+        )
+        sys.exit(-1)
+
+    if os.listdir(output_dir):
+        print(
+            f"output directory '{output_dir}' is not empty",
+            file=sys.stderr,
+        )
+        sys.exit(-1)
+
+    fix_files(input_dir, output_dir)
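
Concretely, leave_Call folds any flattened arguments into a single request dict in METHOD_TO_PARAMS order and re-attaches the retry/timeout/metadata control parameters as keywords. An illustrative before/after for one call site (hypothetical user code, not part of this patch):

    # Before fixup: flattened positional/keyword arguments.
    client.generate_text("models/text-bison-001", prompt, 0.7, timeout=30.0)

    # After fixup: API arguments folded into a "request" dict, with the
    # control parameter left as a keyword argument.
    client.generate_text(request={
        'model': "models/text-bison-001",
        'prompt': prompt,
        'temperature': 0.7,
    }, timeout=30.0)
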
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/setup.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/setup.py
new file mode 100644
index 000000000000..0e0b1e55d45f
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/setup.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import io
+import os
+
+import setuptools  # type: ignore
+
+package_root = os.path.abspath(os.path.dirname(__file__))
+
+name = 'google-ai-generativelanguage'
+
+
+description = "Google Ai Generativelanguage API client library"
+
+version = {}
+with open(os.path.join(package_root, 'google/ai/generativelanguage/gapic_version.py')) as fp:
+    exec(fp.read(), version)
+version = version["__version__"]
+
+if version[0] == "0":
+    release_status = "Development Status :: 4 - Beta"
+else:
+    release_status = "Development Status :: 5 - Production/Stable"
+
+dependencies = [
+    "google-api-core[grpc] >= 1.34.0, <3.0.0dev,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*,!=2.10.*",
+    "proto-plus >= 1.22.0, <2.0.0dev",
+    "proto-plus >= 1.22.2, <2.0.0dev; python_version>='3.11'",
+    "protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5",
+]
+url = "https://github.com/googleapis/python-ai-generativelanguage"
+
+package_root = os.path.abspath(os.path.dirname(__file__))
+
+readme_filename = os.path.join(package_root, "README.rst")
+with io.open(readme_filename, encoding="utf-8") as readme_file:
+    readme = readme_file.read()
+
+packages = [
+    package
+    for package in setuptools.PEP420PackageFinder.find()
+    if package.startswith("google")
+]
+
+namespaces = ["google", "google.ai"]
+
+setuptools.setup(
+    name=name,
+    version=version,
+    description=description,
+    long_description=readme,
+    author="Google LLC",
+    author_email="googleapis-packages@google.com",
+    license="Apache 2.0",
+    url=url,
+    classifiers=[
+        release_status,
+        "Intended Audience :: Developers",
+        "License :: OSI Approved :: Apache Software License",
+        "Programming Language :: Python",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
+        "Operating System :: OS Independent",
+        "Topic :: Internet",
+    ],
+    platforms="Posix; MacOS X; Windows",
+    packages=packages,
+    python_requires=">=3.7",
+    namespace_packages=namespaces,
+    install_requires=dependencies,
+    include_package_data=True,
+    zip_safe=False,
+)
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.10.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.10.txt
new file mode 100644
index 000000000000..ed7f9aed2559
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.10.txt
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+# This constraints file is required for unit tests.
+# List all library dependencies and extras in this file.
+google-api-core
+proto-plus
+protobuf
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.11.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.11.txt
new file mode 100644
index 000000000000..ed7f9aed2559
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.11.txt
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+# This constraints file is required for unit tests.
+# List all library dependencies and extras in this file.
+google-api-core +proto-plus +protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.12.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.12.txt new file mode 100644 index 000000000000..ed7f9aed2559 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.12.txt @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +# This constraints file is required for unit tests. +# List all library dependencies and extras in this file. +google-api-core +proto-plus +protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.7.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.7.txt new file mode 100644 index 000000000000..6c44adfea7ee --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.7.txt @@ -0,0 +1,9 @@ +# This constraints file is used to check that lower bounds +# are correct in setup.py +# List all library dependencies and extras in this file. +# Pin the version to the lower bound. +# e.g., if setup.py has "google-cloud-foo >= 1.14.0, < 2.0.0dev", +# Then this file should have google-cloud-foo==1.14.0 +google-api-core==1.34.0 +proto-plus==1.22.0 +protobuf==3.19.5 diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.8.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.8.txt new file mode 100644 index 000000000000..ed7f9aed2559 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.8.txt @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +# This constraints file is required for unit tests. +# List all library dependencies and extras in this file. +google-api-core +proto-plus +protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.9.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.9.txt new file mode 100644 index 000000000000..ed7f9aed2559 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.9.txt @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +# This constraints file is required for unit tests. +# List all library dependencies and extras in this file. +google-api-core +proto-plus +protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/__init__.py new file mode 100644 index 000000000000..1b4db446eb8d --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/__init__.py @@ -0,0 +1,16 @@ + +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/__init__.py new file mode 100644 index 000000000000..1b4db446eb8d --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/__init__.py @@ -0,0 +1,16 @@ + +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/__init__.py new file mode 100644 index 000000000000..1b4db446eb8d --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/__init__.py @@ -0,0 +1,16 @@ + +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/__init__.py new file mode 100644 index 000000000000..1b4db446eb8d --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/__init__.py @@ -0,0 +1,16 @@ + +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_discuss_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_discuss_service.py new file mode 100644 index 000000000000..fa35eaf42fd5 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_discuss_service.py @@ -0,0 +1,2205 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +import grpc +from grpc.experimental import aio +from collections.abc import Iterable +from google.protobuf import json_format +import json +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule +from proto.marshal.rules import wrappers +from requests import Response +from requests import Request, PreparedRequest +from requests.sessions import Session +from google.protobuf import json_format + +from google.ai.generativelanguage_v1beta2.services.discuss_service import DiscussServiceAsyncClient +from google.ai.generativelanguage_v1beta2.services.discuss_service import DiscussServiceClient +from google.ai.generativelanguage_v1beta2.services.discuss_service import transports +from google.ai.generativelanguage_v1beta2.types import citation +from google.ai.generativelanguage_v1beta2.types import discuss_service +from google.ai.generativelanguage_v1beta2.types import safety +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import path_template +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.oauth2 import service_account +import google.auth + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
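+# As test__get_default_mtls_endpoint below demonstrates, the mTLS variant of a
+# *.googleapis.com endpoint is derived by inserting "mtls" into the hostname
+# (for example, "example.googleapis.com" becomes "example.mtls.googleapis.com");
+# endpoints outside googleapis.com are returned unchanged.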
+def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert DiscussServiceClient._get_default_mtls_endpoint(None) is None + assert DiscussServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert DiscussServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert DiscussServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert DiscussServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert DiscussServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class,transport_name", [ + (DiscussServiceClient, "grpc"), + (DiscussServiceAsyncClient, "grpc_asyncio"), + (DiscussServiceClient, "rest"), +]) +def test_discuss_service_client_from_service_account_info(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + 'generativelanguage.googleapis.com:443' + if transport_name in ['grpc', 'grpc_asyncio'] + else + 'https://generativelanguage.googleapis.com' + ) + + +@pytest.mark.parametrize("transport_class,transport_name", [ + (transports.DiscussServiceGrpcTransport, "grpc"), + (transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.DiscussServiceRestTransport, "rest"), +]) +def test_discuss_service_client_service_account_always_use_jwt(transport_class, transport_name): + with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize("client_class,transport_name", [ + (DiscussServiceClient, "grpc"), + (DiscussServiceAsyncClient, "grpc_asyncio"), + (DiscussServiceClient, "rest"), +]) +def test_discuss_service_client_from_service_account_file(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert 
client.transport._host == ( + 'generativelanguage.googleapis.com:443' + if transport_name in ['grpc', 'grpc_asyncio'] + else + 'https://generativelanguage.googleapis.com' + ) + + +def test_discuss_service_client_get_transport_class(): + transport = DiscussServiceClient.get_transport_class() + available_transports = [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceRestTransport, + ] + assert transport in available_transports + + transport = DiscussServiceClient.get_transport_class("grpc") + assert transport == transports.DiscussServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc"), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest"), +]) +@mock.patch.object(DiscussServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceClient)) +@mock.patch.object(DiscussServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceAsyncClient)) +def test_discuss_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(DiscussServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(DiscussServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class(transport=transport_name) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class(transport=transport_name) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions(api_audience="https://language.googleapis.com") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com" + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", "true"), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", "false"), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", "true"), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", "false"), +]) +@mock.patch.object(DiscussServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceClient)) +@mock.patch.object(DiscussServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_discuss_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. 
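+    # In effect, the matrix exercised below is:
+    #   GOOGLE_API_USE_CLIENT_CERTIFICATE == "true" and a cert is available
+    #       -> DEFAULT_MTLS_ENDPOINT, client cert sent
+    #   "false", or no cert available
+    #       -> DEFAULT_ENDPOINT, no client cert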
+ + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class", [ + DiscussServiceClient, DiscussServiceAsyncClient +]) +@mock.patch.object(DiscussServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceClient)) +@mock.patch.object(DiscussServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceAsyncClient)) +def test_discuss_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
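+    # get_mtls_endpoint_and_cert_source() returns the resolved (api_endpoint,
+    # cert_source) pair directly, so each environment combination can be
+    # asserted on without constructing a client.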
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc"), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest"), +]) +def test_discuss_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. 
+ options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", grpc_helpers), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", None), +]) +def test_discuss_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +def test_discuss_service_client_client_options_from_dict(): + with mock.patch('google.ai.generativelanguage_v1beta2.services.discuss_service.transports.DiscussServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = DiscussServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", grpc_helpers), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), +]) +def test_discuss_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. 
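+    # create_channel should receive the credentials loaded from the file rather
+    # than ADC, along with the default host and port; the -1 channel options
+    # lift gRPC's default message size caps.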
+ with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=( +), + scopes=None, + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize("request_type", [ + discuss_service.GenerateMessageRequest, + dict, +]) +def test_generate_message(request_type, transport: str = 'grpc'): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_message), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = discuss_service.GenerateMessageResponse( + ) + response = client.generate_message(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == discuss_service.GenerateMessageRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.GenerateMessageResponse) + + +def test_generate_message_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_message), + '__call__') as call: + client.generate_message() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == discuss_service.GenerateMessageRequest() + +@pytest.mark.asyncio +async def test_generate_message_async(transport: str = 'grpc_asyncio', request_type=discuss_service.GenerateMessageRequest): + client = DiscussServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_message), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.GenerateMessageResponse( + )) + response = await client.generate_message(request) + + # Establish that the underlying gRPC stub method was called. 
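+        # Note: the async variants only assert that at least one call was
+        # recorded, while the sync tests pin the call count to exactly one.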
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == discuss_service.GenerateMessageRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.GenerateMessageResponse) + + +@pytest.mark.asyncio +async def test_generate_message_async_from_dict(): + await test_generate_message_async(request_type=dict) + + +def test_generate_message_field_headers(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = discuss_service.GenerateMessageRequest() + + request.model = 'model_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_message), + '__call__') as call: + call.return_value = discuss_service.GenerateMessageResponse() + client.generate_message(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model=model_value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_generate_message_field_headers_async(): + client = DiscussServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = discuss_service.GenerateMessageRequest() + + request.model = 'model_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_message), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.GenerateMessageResponse()) + await client.generate_message(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model=model_value', + ) in kw['metadata'] + + +def test_generate_message_flattened(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_message), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = discuss_service.GenerateMessageResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.generate_message( + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + + # Establish that the underlying call was made with the expected + # request object values. 
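+        # The flattened keyword arguments are folded into a single request proto,
+        # so each field is read back off args[0]; float fields are compared with
+        # math.isclose to tolerate floating-point rounding.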
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = 'model_value' + assert arg == mock_val + arg = args[0].prompt + mock_val = discuss_service.MessagePrompt(context='context_value') + assert arg == mock_val + assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) + arg = args[0].candidate_count + mock_val = 1573 + assert arg == mock_val + assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6) + arg = args[0].top_k + mock_val = 541 + assert arg == mock_val + + +def test_generate_message_flattened_error(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.generate_message( + discuss_service.GenerateMessageRequest(), + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + +@pytest.mark.asyncio +async def test_generate_message_flattened_async(): + client = DiscussServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_message), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = discuss_service.GenerateMessageResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.GenerateMessageResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.generate_message( + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = 'model_value' + assert arg == mock_val + arg = args[0].prompt + mock_val = discuss_service.MessagePrompt(context='context_value') + assert arg == mock_val + assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) + arg = args[0].candidate_count + mock_val = 1573 + assert arg == mock_val + assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6) + arg = args[0].top_k + mock_val = 541 + assert arg == mock_val + +@pytest.mark.asyncio +async def test_generate_message_flattened_error_async(): + client = DiscussServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.generate_message( + discuss_service.GenerateMessageRequest(), + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + + +@pytest.mark.parametrize("request_type", [ + discuss_service.CountMessageTokensRequest, + dict, +]) +def test_count_message_tokens(request_type, transport: str = 'grpc'): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
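+    # The fixture numbers below (e.g. token_count=1193) are arbitrary
+    # generator-chosen values; the assertions only check that they survive the
+    # round trip unchanged.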
+ request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = discuss_service.CountMessageTokensResponse( + token_count=1193, + ) + response = client.count_message_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == discuss_service.CountMessageTokensRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.CountMessageTokensResponse) + assert response.token_count == 1193 + + +def test_count_message_tokens_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), + '__call__') as call: + client.count_message_tokens() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == discuss_service.CountMessageTokensRequest() + +@pytest.mark.asyncio +async def test_count_message_tokens_async(transport: str = 'grpc_asyncio', request_type=discuss_service.CountMessageTokensRequest): + client = DiscussServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.CountMessageTokensResponse( + token_count=1193, + )) + response = await client.count_message_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == discuss_service.CountMessageTokensRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.CountMessageTokensResponse) + assert response.token_count == 1193 + + +@pytest.mark.asyncio +async def test_count_message_tokens_async_from_dict(): + await test_count_message_tokens_async(request_type=dict) + + +def test_count_message_tokens_field_headers(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = discuss_service.CountMessageTokensRequest() + + request.model = 'model_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), + '__call__') as call: + call.return_value = discuss_service.CountMessageTokensResponse() + client.count_message_tokens(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model=model_value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_count_message_tokens_field_headers_async(): + client = DiscussServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = discuss_service.CountMessageTokensRequest() + + request.model = 'model_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.CountMessageTokensResponse()) + await client.count_message_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model=model_value', + ) in kw['metadata'] + + +def test_count_message_tokens_flattened(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = discuss_service.CountMessageTokensResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.count_message_tokens( + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = 'model_value' + assert arg == mock_val + arg = args[0].prompt + mock_val = discuss_service.MessagePrompt(context='context_value') + assert arg == mock_val + + +def test_count_message_tokens_flattened_error(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.count_message_tokens( + discuss_service.CountMessageTokensRequest(), + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + ) + +@pytest.mark.asyncio +async def test_count_message_tokens_flattened_async(): + client = DiscussServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = discuss_service.CountMessageTokensResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.CountMessageTokensResponse()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
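+        # FakeUnaryUnaryCall wraps the canned response in an awaitable, mimicking
+        # the call object a real async gRPC stub returns, so the client can await
+        # it as usual.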
+ response = await client.count_message_tokens( + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = 'model_value' + assert arg == mock_val + arg = args[0].prompt + mock_val = discuss_service.MessagePrompt(context='context_value') + assert arg == mock_val + +@pytest.mark.asyncio +async def test_count_message_tokens_flattened_error_async(): + client = DiscussServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.count_message_tokens( + discuss_service.CountMessageTokensRequest(), + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + ) + + +@pytest.mark.parametrize("request_type", [ + discuss_service.GenerateMessageRequest, + dict, +]) +def test_generate_message_rest(request_type): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'model': 'models/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = discuss_service.GenerateMessageResponse( + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = discuss_service.GenerateMessageResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.generate_message(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.GenerateMessageResponse) + + +def test_generate_message_rest_required_fields(request_type=discuss_service.GenerateMessageRequest): + transport_class = transports.DiscussServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_message._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = 'model_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_message._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == 'model_value' + + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. 
+ return_value = discuss_service.GenerateMessageResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + 'uri': 'v1/sample_method', + 'method': "post", + 'query_params': pb_request, + } + transcode_result['body'] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = discuss_service.GenerateMessageResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + response = client.generate_message(request) + + expected_params = [ + ('$alt', 'json;enum-encoding=int') + ] + actual_params = req.call_args.kwargs['params'] + assert expected_params == actual_params + + +def test_generate_message_rest_unset_required_fields(): + transport = transports.DiscussServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + + unset_fields = transport.generate_message._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("model", "prompt", ))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_generate_message_rest_interceptors(null_interceptor): + transport = transports.DiscussServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None if null_interceptor else transports.DiscussServiceRestInterceptor(), + ) + client = DiscussServiceClient(transport=transport) + with mock.patch.object(type(client.transport._session), "request") as req, \ + mock.patch.object(path_template, "transcode") as transcode, \ + mock.patch.object(transports.DiscussServiceRestInterceptor, "post_generate_message") as post, \ + mock.patch.object(transports.DiscussServiceRestInterceptor, "pre_generate_message") as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = discuss_service.GenerateMessageRequest.pb(discuss_service.GenerateMessageRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = discuss_service.GenerateMessageResponse.to_json(discuss_service.GenerateMessageResponse()) + + request = discuss_service.GenerateMessageRequest() + metadata =[ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = discuss_service.GenerateMessageResponse() + + client.generate_message(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + + pre.assert_called_once() + post.assert_called_once() + + +def test_generate_message_rest_bad_request(transport: str = 'rest', request_type=discuss_service.GenerateMessageRequest): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {'model': 'models/sample1'} + request = 
request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.generate_message(request) + + +def test_generate_message_rest_flattened(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = discuss_service.GenerateMessageResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {'model': 'models/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = discuss_service.GenerateMessageResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.generate_message(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta2/{model=models/*}:generateMessage" % client.transport._host, args[1]) + + +def test_generate_message_rest_flattened_error(transport: str = 'rest'): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.generate_message( + discuss_service.GenerateMessageRequest(), + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + + +def test_generate_message_rest_error(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + discuss_service.CountMessageTokensRequest, + dict, +]) +def test_count_message_tokens_rest(request_type): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'model': 'models/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. 
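+        # The REST happy path: serialize the proto to JSON exactly as the server
+        # would, attach it to a requests.Response with HTTP 200, and let the
+        # transport parse it back into a proto-plus message.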
+ return_value = discuss_service.CountMessageTokensResponse( + token_count=1193, + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = discuss_service.CountMessageTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.count_message_tokens(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.CountMessageTokensResponse) + assert response.token_count == 1193 + + +def test_count_message_tokens_rest_required_fields(request_type=discuss_service.CountMessageTokensRequest): + transport_class = transports.DiscussServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).count_message_tokens._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = 'model_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).count_message_tokens._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == 'model_value' + + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = discuss_service.CountMessageTokensResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
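+            # 'v1/sample_method' below is only a placeholder URI; the real
+            # mapping (v1beta2/{model=models/*}:countMessageTokens) is irrelevant
+            # once transcode() is stubbed out.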
+ pb_request = request_type.pb(request) + transcode_result = { + 'uri': 'v1/sample_method', + 'method': "post", + 'query_params': pb_request, + } + transcode_result['body'] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + pb_return_value = discuss_service.CountMessageTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + response = client.count_message_tokens(request) + + expected_params = [ + ('$alt', 'json;enum-encoding=int') + ] + actual_params = req.call_args.kwargs['params'] + assert expected_params == actual_params + + +def test_count_message_tokens_rest_unset_required_fields(): + transport = transports.DiscussServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + + unset_fields = transport.count_message_tokens._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("model", "prompt", ))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_count_message_tokens_rest_interceptors(null_interceptor): + transport = transports.DiscussServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None if null_interceptor else transports.DiscussServiceRestInterceptor(), + ) + client = DiscussServiceClient(transport=transport) + with mock.patch.object(type(client.transport._session), "request") as req, \ + mock.patch.object(path_template, "transcode") as transcode, \ + mock.patch.object(transports.DiscussServiceRestInterceptor, "post_count_message_tokens") as post, \ + mock.patch.object(transports.DiscussServiceRestInterceptor, "pre_count_message_tokens") as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = discuss_service.CountMessageTokensRequest.pb(discuss_service.CountMessageTokensRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = discuss_service.CountMessageTokensResponse.to_json(discuss_service.CountMessageTokensResponse()) + + request = discuss_service.CountMessageTokensRequest() + metadata =[ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = discuss_service.CountMessageTokensResponse() + + client.count_message_tokens(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + + pre.assert_called_once() + post.assert_called_once() + + +def test_count_message_tokens_rest_bad_request(transport: str = 'rest', request_type=discuss_service.CountMessageTokensRequest): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {'model': 'models/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
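+    # A 400 status from the session should surface to the caller as
+    # google.api_core.exceptions.BadRequest.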
+ with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.count_message_tokens(request) + + +def test_count_message_tokens_rest_flattened(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = discuss_service.CountMessageTokensResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {'model': 'models/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = discuss_service.CountMessageTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.count_message_tokens(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta2/{model=models/*}:countMessageTokens" % client.transport._host, args[1]) + + +def test_count_message_tokens_rest_flattened_error(transport: str = 'rest'): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.count_message_tokens( + discuss_service.CountMessageTokensRequest(), + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + ) + + +def test_count_message_tokens_rest_error(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DiscussServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = DiscussServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. 
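+    # An api_key is likewise mutually exclusive with explicit credentials,
+    # mirroring the transport cases above (a transport already carries its own
+    # credentials).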
+ options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = DiscussServiceClient( + client_options=options, + credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DiscussServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = DiscussServiceClient(transport=transport) + assert client.transport is transport + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.DiscussServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + +@pytest.mark.parametrize("transport_class", [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + transports.DiscussServiceRestTransport, +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, 'default') as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + +@pytest.mark.parametrize("transport_name", [ + "grpc", + "rest", +]) +def test_transport_kind(transport_name): + transport = DiscussServiceClient.get_transport_class(transport_name)( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert transport.kind == transport_name + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.DiscussServiceGrpcTransport, + ) + +def test_discuss_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.DiscussServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_discuss_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.ai.generativelanguage_v1beta2.services.discuss_service.transports.DiscussServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.DiscussServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
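+    # (The base transport only defines the interface; the concrete gRPC and
+    # REST transports are expected to override each of these stubs.)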
+ methods = ( + 'generate_message', + 'count_message_tokens', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Catch all for all remaining methods and properties + remainder = [ + 'kind', + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_discuss_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta2.services.discuss_service.transports.DiscussServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.DiscussServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", + scopes=None, + default_scopes=( +), + quota_project_id="octopus", + ) + + +def test_discuss_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta2.services.discuss_service.transports.DiscussServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.DiscussServiceTransport() + adc.assert_called_once() + + +def test_discuss_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, 'default', autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + DiscussServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=( +), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + ], +) +def test_discuss_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
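+    # (ADC means Application Default Credentials, resolved through
+    # google.auth.default(); the mock below stands in for that lookup.)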
+ with mock.patch.object(google.auth, 'default', autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=(), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + transports.DiscussServiceRestTransport, + ], +) +def test_discuss_service_transport_auth_gdch_credentials(transport_class): + host = 'https://language.com' + api_audience_tests = [None, 'https://language2.com'] + api_audience_expect = [host, 'https://language2.com'] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, 'default', autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with( + e + ) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.DiscussServiceGrpcTransport, grpc_helpers), + (transports.DiscussServiceGrpcAsyncIOTransport, grpc_helpers_async) + ], +) +def test_discuss_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class( + quota_project_id="octopus", + scopes=["1", "2"] + ) + + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=( +), + scopes=["1", "2"], + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize("transport_class", [transports.DiscussServiceGrpcTransport, transports.DiscussServiceGrpcAsyncIOTransport]) +def test_discuss_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. 
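+    # (client_cert_source_callback returns a (cert, key) pair, which the
+    # transport is expected to hand to grpc.ssl_channel_credentials below.)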
+    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+            transport_class(
+                credentials=cred,
+                client_cert_source_for_mtls=client_cert_source_callback
+            )
+            expected_cert, expected_key = client_cert_source_callback()
+            mock_ssl_cred.assert_called_once_with(
+                certificate_chain=expected_cert,
+                private_key=expected_key
+            )
+
+def test_discuss_service_http_transport_client_cert_source_for_mtls():
+    cred = ga_credentials.AnonymousCredentials()
+    with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel:
+        transports.DiscussServiceRestTransport(
+            credentials=cred,
+            client_cert_source_for_mtls=client_cert_source_callback
+        )
+        mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)
+
+
+@pytest.mark.parametrize("transport_name", [
+    "grpc",
+    "grpc_asyncio",
+    "rest",
+])
+def test_discuss_service_host_no_port(transport_name):
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'),
+        transport=transport_name,
+    )
+    assert client.transport._host == (
+        'generativelanguage.googleapis.com:443'
+        if transport_name in ['grpc', 'grpc_asyncio']
+        else 'https://generativelanguage.googleapis.com'
+    )
+
+@pytest.mark.parametrize("transport_name", [
+    "grpc",
+    "grpc_asyncio",
+    "rest",
+])
+def test_discuss_service_host_with_port(transport_name):
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'),
+        transport=transport_name,
+    )
+    assert client.transport._host == (
+        'generativelanguage.googleapis.com:8000'
+        if transport_name in ['grpc', 'grpc_asyncio']
+        else 'https://generativelanguage.googleapis.com:8000'
+    )
+
+@pytest.mark.parametrize("transport_name", [
+    "rest",
+])
+def test_discuss_service_client_transport_session_collision(transport_name):
+    creds1 = ga_credentials.AnonymousCredentials()
+    creds2 = ga_credentials.AnonymousCredentials()
+    client1 = DiscussServiceClient(
+        credentials=creds1,
+        transport=transport_name,
+    )
+    client2 = DiscussServiceClient(
+        credentials=creds2,
+        transport=transport_name,
+    )
+    session1 = client1.transport.generate_message._session
+    session2 = client2.transport.generate_message._session
+    assert session1 != session2
+    session1 = client1.transport.count_message_tokens._session
+    session2 = client2.transport.count_message_tokens._session
+    assert session1 != session2
+
+
+def test_discuss_service_grpc_transport_channel():
+    channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.DiscussServiceGrpcTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_discuss_service_grpc_asyncio_transport_channel():
+    channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.DiscussServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize("transport_class", [transports.DiscussServiceGrpcTransport, transports.DiscussServiceGrpcAsyncIOTransport])
+def test_discuss_service_transport_channel_mtls_with_client_cert_source(
+    transport_class
+):
+    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
+        with mock.patch.object(transport_class, "create_channel") as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = ga_credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(google.auth, 'default') as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=None,
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
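+# (These arguments are deprecated in favor of client_cert_source_for_mtls; the
+# pytest.warns(DeprecationWarning) blocks pin the warning behavior meanwhile.)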
+@pytest.mark.parametrize("transport_class", [transports.DiscussServiceGrpcTransport, transports.DiscussServiceGrpcAsyncIOTransport]) +def test_discuss_service_transport_channel_mtls_with_adc( + transport_class +): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_model_path(): + model = "squid" + expected = "models/{model}".format(model=model, ) + actual = DiscussServiceClient.model_path(model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "model": "clam", + } + path = DiscussServiceClient.model_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_model_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "whelk" + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = DiscussServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "octopus", + } + path = DiscussServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "oyster" + expected = "folders/{folder}".format(folder=folder, ) + actual = DiscussServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nudibranch", + } + path = DiscussServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "cuttlefish" + expected = "organizations/{organization}".format(organization=organization, ) + actual = DiscussServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "mussel", + } + path = DiscussServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. 
+ actual = DiscussServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "winkle" + expected = "projects/{project}".format(project=project, ) + actual = DiscussServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "nautilus", + } + path = DiscussServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "scallop" + location = "abalone" + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = DiscussServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "squid", + "location": "clam", + } + path = DiscussServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.DiscussServiceTransport, '_prep_wrapped_messages') as prep: + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.DiscussServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = DiscussServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = DiscussServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "rest": "_session", + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport + ) + with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close: + with client: + close.assert_not_called() + close.assert_called_once() + +def test_client_ctx(): + transports = [ + 'rest', + 'grpc', + ] + for transport in transports: + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport + ) + # Test client calls underlying transport. 
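+        # (Leaving the client context manager should close the transport, as
+        # close.assert_called() verifies below.)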
+ with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + +@pytest.mark.parametrize("client_class,transport_class", [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport), +]) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_model_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_model_service.py new file mode 100644 index 000000000000..c7a1ee1f30f8 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_model_service.py @@ -0,0 +1,2319 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+import os
+# try/except added for compatibility with python < 3.8
+try:
+    from unittest import mock
+    from unittest.mock import AsyncMock  # pragma: NO COVER
+except ImportError:  # pragma: NO COVER
+    import mock
+
+import grpc
+from grpc.experimental import aio
+from collections.abc import Iterable
+from google.protobuf import json_format
+import json
+import math
+import pytest
+from proto.marshal.rules.dates import DurationRule, TimestampRule
+from proto.marshal.rules import wrappers
+from requests import Response
+from requests import Request, PreparedRequest
+from requests.sessions import Session
+
+from google.ai.generativelanguage_v1beta2.services.model_service import ModelServiceAsyncClient
+from google.ai.generativelanguage_v1beta2.services.model_service import ModelServiceClient
+from google.ai.generativelanguage_v1beta2.services.model_service import pagers
+from google.ai.generativelanguage_v1beta2.services.model_service import transports
+from google.ai.generativelanguage_v1beta2.types import model
+from google.ai.generativelanguage_v1beta2.types import model_service
+from google.api_core import client_options
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers
+from google.api_core import grpc_helpers_async
+from google.api_core import path_template
+from google.auth import credentials as ga_credentials
+from google.auth.exceptions import MutualTLSChannelError
+from google.oauth2 import service_account
+import google.auth
+
+
+def client_cert_source_callback():
+    return b"cert bytes", b"key bytes"
+
+
+# If default endpoint is localhost, then default mtls endpoint will be the same.
+# This method modifies the default endpoint so the client can produce a different
+# mtls endpoint for endpoint testing purposes.
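+# (For this service the default endpoint is generativelanguage.googleapis.com,
+# which contains no "localhost" and is therefore returned unchanged.)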
+def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert ModelServiceClient._get_default_mtls_endpoint(None) is None + assert ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class,transport_name", [ + (ModelServiceClient, "grpc"), + (ModelServiceAsyncClient, "grpc_asyncio"), + (ModelServiceClient, "rest"), +]) +def test_model_service_client_from_service_account_info(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + 'generativelanguage.googleapis.com:443' + if transport_name in ['grpc', 'grpc_asyncio'] + else + 'https://generativelanguage.googleapis.com' + ) + + +@pytest.mark.parametrize("transport_class,transport_name", [ + (transports.ModelServiceGrpcTransport, "grpc"), + (transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.ModelServiceRestTransport, "rest"), +]) +def test_model_service_client_service_account_always_use_jwt(transport_class, transport_name): + with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize("client_class,transport_name", [ + (ModelServiceClient, "grpc"), + (ModelServiceAsyncClient, "grpc_asyncio"), + (ModelServiceClient, "rest"), +]) +def test_model_service_client_from_service_account_file(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + 
'generativelanguage.googleapis.com:443' + if transport_name in ['grpc', 'grpc_asyncio'] + else + 'https://generativelanguage.googleapis.com' + ) + + +def test_model_service_client_get_transport_class(): + transport = ModelServiceClient.get_transport_class() + available_transports = [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceRestTransport, + ] + assert transport in available_transports + + transport = ModelServiceClient.get_transport_class("grpc") + assert transport == transports.ModelServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest"), +]) +@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) +@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) +def test_model_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class(transport=transport_name) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class(transport=transport_name) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions(api_audience="https://language.googleapis.com") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com" + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest", "true"), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest", "false"), +]) +@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) +@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_model_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. 
Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class", [ + ModelServiceClient, ModelServiceAsyncClient +]) +@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) +@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) +def test_model_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest"), +]) +def test_model_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. 
+ options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", grpc_helpers), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest", None), +]) +def test_model_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +def test_model_service_client_client_options_from_dict(): + with mock.patch('google.ai.generativelanguage_v1beta2.services.model_service.transports.ModelServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = ModelServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", grpc_helpers), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), +]) +def test_model_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. 
+ with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=( +), + scopes=None, + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize("request_type", [ + model_service.GetModelRequest, + dict, +]) +def test_get_model(request_type, transport: str = 'grpc'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model.Model( + name='name_value', + base_model_id='base_model_id_value', + version='version_value', + display_name='display_name_value', + description='description_value', + input_token_limit=1838, + output_token_limit=1967, + supported_generation_methods=['supported_generation_methods_value'], + temperature=0.1198, + top_p=0.546, + top_k=541, + ) + response = client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.GetModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, model.Model) + assert response.name == 'name_value' + assert response.base_model_id == 'base_model_id_value' + assert response.version == 'version_value' + assert response.display_name == 'display_name_value' + assert response.description == 'description_value' + assert response.input_token_limit == 1838 + assert response.output_token_limit == 1967 + assert response.supported_generation_methods == ['supported_generation_methods_value'] + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + + +def test_get_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+    with mock.patch.object(
+            type(client.transport.get_model),
+            '__call__') as call:
+        client.get_model()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == model_service.GetModelRequest()
+
+@pytest.mark.asyncio
+async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelRequest):
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model(
+            name='name_value',
+            base_model_id='base_model_id_value',
+            version='version_value',
+            display_name='display_name_value',
+            description='description_value',
+            input_token_limit=1838,
+            output_token_limit=1967,
+            supported_generation_methods=['supported_generation_methods_value'],
+            temperature=0.1198,
+            top_p=0.546,
+            top_k=541,
+        ))
+        response = await client.get_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == model_service.GetModelRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, model.Model)
+    assert response.name == 'name_value'
+    assert response.base_model_id == 'base_model_id_value'
+    assert response.version == 'version_value'
+    assert response.display_name == 'display_name_value'
+    assert response.description == 'description_value'
+    assert response.input_token_limit == 1838
+    assert response.output_token_limit == 1967
+    assert response.supported_generation_methods == ['supported_generation_methods_value']
+    assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6)
+    assert math.isclose(response.top_p, 0.546, rel_tol=1e-6)
+    assert response.top_k == 541
+
+
+@pytest.mark.asyncio
+async def test_get_model_async_from_dict():
+    await test_get_model_async(request_type=dict)
+
+
+def test_get_model_field_headers():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.GetModelRequest()
+
+    request.name = 'name_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_model),
+            '__call__') as call:
+        call.return_value = model.Model()
+        client.get_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'name=name_value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_get_model_field_headers_async():
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
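+    # (The routing header travels as x-goog-request-params metadata,
+    # e.g. 'name=name_value' for this request.)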
+ request = model_service.GetModelRequest() + + request.name = 'name_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) + await client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name_value', + ) in kw['metadata'] + + +def test_get_model_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model.Model() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_model( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = 'name_value' + assert arg == mock_val + + +def test_get_model_flattened_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_model( + model_service.GetModelRequest(), + name='name_value', + ) + +@pytest.mark.asyncio +async def test_get_model_flattened_async(): + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = model.Model() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_model( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = 'name_value' + assert arg == mock_val + +@pytest.mark.asyncio +async def test_get_model_flattened_error_async(): + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_model( + model_service.GetModelRequest(), + name='name_value', + ) + + +@pytest.mark.parametrize("request_type", [ + model_service.ListModelsRequest, + dict, +]) +def test_list_models(request_type, transport: str = 'grpc'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_service.ListModelsResponse(
+            next_page_token='next_page_token_value',
+        )
+        response = client.list_models(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == model_service.ListModelsRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListModelsPager)
+    assert response.next_page_token == 'next_page_token_value'
+
+
+def test_list_models_empty_call():
+    # This test is a coverage failsafe to make sure that totally empty calls,
+    # i.e. request == None and no flattened fields passed, work.
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__') as call:
+        client.list_models()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == model_service.ListModelsRequest()
+
+@pytest.mark.asyncio
+async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelsRequest):
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse(
+            next_page_token='next_page_token_value',
+        ))
+        response = await client.list_models(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == model_service.ListModelsRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListModelsAsyncPager)
+    assert response.next_page_token == 'next_page_token_value'
+
+
+@pytest.mark.asyncio
+async def test_list_models_async_from_dict():
+    await test_list_models_async(request_type=dict)
+
+
+def test_list_models_flattened():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_service.ListModelsResponse()
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_models(
+            page_size=951,
+            page_token='page_token_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].page_size
+        mock_val = 951
+        assert arg == mock_val
+        arg = args[0].page_token
+        mock_val = 'page_token_value'
+        assert arg == mock_val
+
+
+def test_list_models_flattened_error():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_models(
+            model_service.ListModelsRequest(),
+            page_size=951,
+            page_token='page_token_value',
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_models_flattened_async():
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_models(
+            page_size=951,
+            page_token='page_token_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].page_size
+        mock_val = 951
+        assert arg == mock_val
+        arg = args[0].page_token
+        mock_val = 'page_token_value'
+        assert arg == mock_val
+
+
+@pytest.mark.asyncio
+async def test_list_models_flattened_error_async():
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_models(
+            model_service.ListModelsRequest(),
+            page_size=951,
+            page_token='page_token_value',
+        )
+
+
+def test_list_models_pager(transport_name: str = "grpc"):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport_name,
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                    model.Model(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelsResponse(
+                models=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        pager = client.list_models(request={})
+
+        assert pager._metadata == metadata
+
+        results = list(pager)
+        assert len(results) == 6
+        assert all(isinstance(i, model.Model)
+                   for i in results)
+
+
+def test_list_models_pages(transport_name: str = "grpc"):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport_name,
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__') as call:
+        # Set the response to a series of pages.
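+        # (mock returns the side_effect entries one per call; the trailing
+        # RuntimeError would only surface if an extra page were requested.)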
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                    model.Model(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelsResponse(
+                models=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_models(request={}).pages)
+        for page_, token in zip(pages, ['abc', 'def', 'ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_models_async_pager():
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                    model.Model(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelsResponse(
+                models=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_models(request={},)
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager:  # pragma: no branch
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, model.Model)
+                   for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_models_async_pages():
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_models),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                    model.Model(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelsResponse(
+                models=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch`
+        # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372
+        async for page_ in (  # pragma: no branch
+            await client.list_models(request={})
+        ).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ['abc', 'def', 'ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.parametrize("request_type", [
+    model_service.GetModelRequest,
+    dict,
+])
+def test_get_model_rest(request_type):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'name': 'models/sample1'}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a response.
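+    # (The REST transport sends requests through the session object stored on
+    # client.transport._session, so patching its request() intercepts the HTTP call.)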
+    with mock.patch.object(type(client.transport._session), 'request') as req:
+        # Designate an appropriate value for the returned response.
+        return_value = model.Model(
+            name='name_value',
+            base_model_id='base_model_id_value',
+            version='version_value',
+            display_name='display_name_value',
+            description='description_value',
+            input_token_limit=1838,
+            output_token_limit=1967,
+            supported_generation_methods=['supported_generation_methods_value'],
+            temperature=0.1198,
+            top_p=0.546,
+            top_k=541,
+        )
+
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 200
+        pb_return_value = model.Model.pb(return_value)
+        json_return_value = json_format.MessageToJson(pb_return_value)
+
+        response_value._content = json_return_value.encode('UTF-8')
+        req.return_value = response_value
+        response = client.get_model(request)
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, model.Model)
+    assert response.name == 'name_value'
+    assert response.base_model_id == 'base_model_id_value'
+    assert response.version == 'version_value'
+    assert response.display_name == 'display_name_value'
+    assert response.description == 'description_value'
+    assert response.input_token_limit == 1838
+    assert response.output_token_limit == 1967
+    assert response.supported_generation_methods == ['supported_generation_methods_value']
+    assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6)
+    assert math.isclose(response.top_p, 0.546, rel_tol=1e-6)
+    assert response.top_k == 541
+
+
+def test_get_model_rest_required_fields(request_type=model_service.GetModelRequest):
+    transport_class = transports.ModelServiceRestTransport
+
+    request_init = {}
+    request_init["name"] = ""
+    request = request_type(**request_init)
+    pb_request = request_type.pb(request)
+    jsonified_request = json.loads(json_format.MessageToJson(
+        pb_request,
+        including_default_value_fields=False,
+        use_integers_for_enums=False
+    ))
+
+    # verify fields with default values are dropped
+
+    unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_model._get_unset_required_fields(jsonified_request)
+    jsonified_request.update(unset_fields)
+
+    # verify required fields with default values are now present
+
+    jsonified_request["name"] = 'name_value'
+
+    unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_model._get_unset_required_fields(jsonified_request)
+    jsonified_request.update(unset_fields)
+
+    # verify required fields with non-default values are left alone
+    assert "name" in jsonified_request
+    assert jsonified_request["name"] == 'name_value'
+
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='rest',
+    )
+    request = request_type(**request_init)
+
+    # Designate an appropriate value for the returned response.
+    return_value = model.Model()
+    # Mock the http request call within the method and fake a response.
+    with mock.patch.object(Session, 'request') as req:
+        # We need to mock transcode() because providing default values
+        # for required fields will fail the real version if the http_options
+        # expect actual values for those fields.
+        with mock.patch.object(path_template, 'transcode') as transcode:
+            # A uri without fields and an empty body will force all the
+            # request fields to show up in the query_params.
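+            # (path_template.transcode() normally maps the request onto the URI
+            # template declared in the method's http_options; it is stubbed here.)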
+            pb_request = request_type.pb(request)
+            transcode_result = {
+                'uri': 'v1/sample_method',
+                'method': "get",
+                'query_params': pb_request,
+            }
+            transcode.return_value = transcode_result
+
+            response_value = Response()
+            response_value.status_code = 200
+
+            pb_return_value = model.Model.pb(return_value)
+            json_return_value = json_format.MessageToJson(pb_return_value)
+
+            response_value._content = json_return_value.encode('UTF-8')
+            req.return_value = response_value
+
+            response = client.get_model(request)
+
+            expected_params = [
+                ('$alt', 'json;enum-encoding=int')
+            ]
+            actual_params = req.call_args.kwargs['params']
+            assert expected_params == actual_params
+
+
+def test_get_model_rest_unset_required_fields():
+    transport = transports.ModelServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+    unset_fields = transport.get_model._get_unset_required_fields({})
+    assert set(unset_fields) == (set(()) & set(("name", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_get_model_rest_interceptors(null_interceptor):
+    transport = transports.ModelServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(),
+    )
+    client = ModelServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+         mock.patch.object(path_template, "transcode") as transcode, \
+         mock.patch.object(transports.ModelServiceRestInterceptor, "post_get_model") as post, \
+         mock.patch.object(transports.ModelServiceRestInterceptor, "pre_get_model") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = model_service.GetModelRequest.pb(model_service.GetModelRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = model.Model.to_json(model.Model())
+
+        request = model_service.GetModelRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = model.Model()
+
+        client.get_model(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_get_model_rest_bad_request(transport: str = 'rest', request_type=model_service.GetModelRequest):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'name': 'models/sample1'}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
+    with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 400
+        response_value.request = Request()
+        req.return_value = response_value
+        client.get_model(request)
+
+
+def test_get_model_rest_flattened():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # Mock the http request call within the method and fake a response.
+    with mock.patch.object(type(client.transport._session), 'request') as req:
+        # Designate an appropriate value for the returned response.
+        return_value = model.Model()
+
+        # get arguments that satisfy an http rule for this method
+        sample_request = {'name': 'models/sample1'}
+
+        # get truthy value for each flattened field
+        mock_args = dict(
+            name='name_value',
+        )
+        mock_args.update(sample_request)
+
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 200
+        pb_return_value = model.Model.pb(return_value)
+        json_return_value = json_format.MessageToJson(pb_return_value)
+        response_value._content = json_return_value.encode('UTF-8')
+        req.return_value = response_value
+
+        client.get_model(**mock_args)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(req.mock_calls) == 1
+        _, args, _ = req.mock_calls[0]
+        assert path_template.validate("%s/v1beta2/{name=models/*}" % client.transport._host, args[1])
+
+
+def test_get_model_rest_flattened_error(transport: str = 'rest'):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_model(
+            model_service.GetModelRequest(),
+            name='name_value',
+        )
+
+
+def test_get_model_rest_error():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='rest'
+    )
+
+
+@pytest.mark.parametrize("request_type", [
+    model_service.ListModelsRequest,
+    dict,
+])
+def test_list_models_rest(request_type):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a response.
+    with mock.patch.object(type(client.transport._session), 'request') as req:
+        # Designate an appropriate value for the returned response.
+        return_value = model_service.ListModelsResponse(
+            next_page_token='next_page_token_value',
+        )
+
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 200
+        pb_return_value = model_service.ListModelsResponse.pb(return_value)
+        json_return_value = json_format.MessageToJson(pb_return_value)
+
+        response_value._content = json_return_value.encode('UTF-8')
+        req.return_value = response_value
+        response = client.list_models(request)
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListModelsPager)
+    assert response.next_page_token == 'next_page_token_value'
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_list_models_rest_interceptors(null_interceptor):
+    transport = transports.ModelServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(),
+    )
+    client = ModelServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+         mock.patch.object(path_template, "transcode") as transcode, \
+         mock.patch.object(transports.ModelServiceRestInterceptor, "post_list_models") as post, \
+         mock.patch.object(transports.ModelServiceRestInterceptor, "pre_list_models") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = model_service.ListModelsRequest.pb(model_service.ListModelsRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = model_service.ListModelsResponse.to_json(model_service.ListModelsResponse())
+
+        request = model_service.ListModelsRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = model_service.ListModelsResponse()
+
+        client.list_models(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_list_models_rest_bad_request(transport: str = 'rest', request_type=model_service.ListModelsRequest):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
+    with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 400
+        response_value.request = Request()
+        req.return_value = response_value
+        client.list_models(request)
+
+
+def test_list_models_rest_flattened():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # Mock the http request call within the method and fake a response.
+    with mock.patch.object(type(client.transport._session), 'request') as req:
+        # Designate an appropriate value for the returned response.
+        return_value = model_service.ListModelsResponse()
+
+        # get arguments that satisfy an http rule for this method
+        sample_request = {}
+
+        # get truthy value for each flattened field
+        mock_args = dict(
+            page_size=951,
+            page_token='page_token_value',
+        )
+        mock_args.update(sample_request)
+
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 200
+        pb_return_value = model_service.ListModelsResponse.pb(return_value)
+        json_return_value = json_format.MessageToJson(pb_return_value)
+        response_value._content = json_return_value.encode('UTF-8')
+        req.return_value = response_value
+
+        client.list_models(**mock_args)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(req.mock_calls) == 1
+        _, args, _ = req.mock_calls[0]
+        assert path_template.validate("%s/v1beta2/models" % client.transport._host, args[1])
+
+
+def test_list_models_rest_flattened_error(transport: str = 'rest'):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_models(
+            model_service.ListModelsRequest(),
+            page_size=951,
+            page_token='page_token_value',
+        )
+
+
+def test_list_models_rest_pager(transport: str = 'rest'):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Mock the http request call within the method and fake a response.
+    with mock.patch.object(Session, 'request') as req:
+        # Set the response as a series of pages
+        response = (
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                    model.Model(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListModelsResponse(
+                models=[],
+                next_page_token='def',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListModelsResponse(
+                models=[
+                    model.Model(),
+                    model.Model(),
+                ],
+            ),
+        )
+        # Two responses for two calls
+        response = response + response
+
+        # Wrap the values into proper Response objs
+        response = tuple(model_service.ListModelsResponse.to_json(x) for x in response)
+        return_values = tuple(Response() for _ in response)
+        for return_val, response_val in zip(return_values, response):
+            return_val._content = response_val.encode('UTF-8')
+            return_val.status_code = 200
+        req.side_effect = return_values
+
+        sample_request = {}
+
+        pager = client.list_models(request=sample_request)
+
+        results = list(pager)
+        assert len(results) == 6
+        assert all(isinstance(i, model.Model)
+                   for i in results)
+
+        pages = list(client.list_models(request=sample_request).pages)
+        for page_, token in zip(pages, ['abc', 'def', 'ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_credentials_transport_error():
+    # It is an error to provide credentials and a transport instance.
+    transport = transports.ModelServiceGrpcTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = ModelServiceClient(
+            credentials=ga_credentials.AnonymousCredentials(),
+            transport=transport,
+        )
+
+    # It is an error to provide a credentials file and a transport instance.
+    transport = transports.ModelServiceGrpcTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = ModelServiceClient(
+            client_options={"credentials_file": "credentials.json"},
+            transport=transport,
+        )
+
+    # It is an error to provide an api_key and a transport instance.
+    transport = transports.ModelServiceGrpcTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    options = client_options.ClientOptions()
+    options.api_key = "api_key"
+    with pytest.raises(ValueError):
+        client = ModelServiceClient(
+            client_options=options,
+            transport=transport,
+        )
+
+    # It is an error to provide an api_key and a credential.
+    options = mock.Mock()
+    options.api_key = "api_key"
+    with pytest.raises(ValueError):
+        client = ModelServiceClient(
+            client_options=options,
+            credentials=ga_credentials.AnonymousCredentials()
+        )
+
+    # It is an error to provide scopes and a transport instance.
+    transport = transports.ModelServiceGrpcTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = ModelServiceClient(
+            client_options={"scopes": ["1", "2"]},
+            transport=transport,
+        )
+
+
+def test_transport_instance():
+    # A client may be instantiated with a custom transport instance.
+    transport = transports.ModelServiceGrpcTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    client = ModelServiceClient(transport=transport)
+    assert client.transport is transport
+
+
+def test_transport_get_channel():
+    # A client may be instantiated with a custom transport instance.
+    transport = transports.ModelServiceGrpcTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    channel = transport.grpc_channel
+    assert channel
+
+    transport = transports.ModelServiceGrpcAsyncIOTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    channel = transport.grpc_channel
+    assert channel
+
+
+@pytest.mark.parametrize("transport_class", [
+    transports.ModelServiceGrpcTransport,
+    transports.ModelServiceGrpcAsyncIOTransport,
+    transports.ModelServiceRestTransport,
+])
+def test_transport_adc(transport_class):
+    # Test default credentials are used if not provided.
+    with mock.patch.object(google.auth, 'default') as adc:
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport_class()
+        adc.assert_called_once()
+
+
+@pytest.mark.parametrize("transport_name", [
+    "grpc",
+    "rest",
+])
+def test_transport_kind(transport_name):
+    transport = ModelServiceClient.get_transport_class(transport_name)(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    assert transport.kind == transport_name
+
+
+def test_transport_grpc_default():
+    # A client should use the gRPC transport by default.
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    assert isinstance(
+        client.transport,
+        transports.ModelServiceGrpcTransport,
+    )
+
+
+def test_model_service_base_transport_error():
+    # Passing both a credentials object and credentials_file should raise an error
+    with pytest.raises(core_exceptions.DuplicateCredentialArgs):
+        transport = transports.ModelServiceTransport(
+            credentials=ga_credentials.AnonymousCredentials(),
+            credentials_file="credentials.json"
+        )
+
+
+def test_model_service_base_transport():
+    # Instantiate the base transport.
+    with mock.patch('google.ai.generativelanguage_v1beta2.services.model_service.transports.ModelServiceTransport.__init__') as Transport:
+        Transport.return_value = None
+        transport = transports.ModelServiceTransport(
+            credentials=ga_credentials.AnonymousCredentials(),
+        )
+
+    # Every method on the transport should just blindly
+    # raise NotImplementedError.
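+    # (The base transport only defines the interface; the concrete gRPC and
+    # REST transports override these stubs.)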
+    methods = (
+        'get_model',
+        'list_models',
+    )
+    for method in methods:
+        with pytest.raises(NotImplementedError):
+            getattr(transport, method)(request=object())
+
+    with pytest.raises(NotImplementedError):
+        transport.close()
+
+    # Catch all for all remaining methods and properties
+    remainder = [
+        'kind',
+    ]
+    for r in remainder:
+        with pytest.raises(NotImplementedError):
+            getattr(transport, r)()
+
+
+def test_model_service_base_transport_with_credentials_file():
+    # Instantiate the base transport with a credentials file
+    with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta2.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport:
+        Transport.return_value = None
+        load_creds.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport = transports.ModelServiceTransport(
+            credentials_file="credentials.json",
+            quota_project_id="octopus",
+        )
+        load_creds.assert_called_once_with("credentials.json",
+            scopes=None,
+            default_scopes=(),
+            quota_project_id="octopus",
+        )
+
+
+def test_model_service_base_transport_with_adc():
+    # Test the default credentials are used if credentials and credentials_file are None.
+    with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta2.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport:
+        Transport.return_value = None
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport = transports.ModelServiceTransport()
+        adc.assert_called_once()
+
+
+def test_model_service_auth_adc():
+    # If no credentials are provided, we should use ADC credentials.
+    with mock.patch.object(google.auth, 'default', autospec=True) as adc:
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        ModelServiceClient()
+        adc.assert_called_once_with(
+            scopes=None,
+            default_scopes=(),
+            quota_project_id=None,
+        )
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.ModelServiceGrpcTransport,
+        transports.ModelServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_model_service_transport_auth_adc(transport_class):
+    # If credentials and host are not provided, the transport class should use
+    # ADC credentials.
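+    # (google.auth.default() resolves Application Default Credentials from the
+    # environment; the tests stub it out so no real credentials are needed.)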
+    with mock.patch.object(google.auth, 'default', autospec=True) as adc:
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport_class(quota_project_id="octopus", scopes=["1", "2"])
+        adc.assert_called_once_with(
+            scopes=["1", "2"],
+            default_scopes=(),
+            quota_project_id="octopus",
+        )
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.ModelServiceGrpcTransport,
+        transports.ModelServiceGrpcAsyncIOTransport,
+        transports.ModelServiceRestTransport,
+    ],
+)
+def test_model_service_transport_auth_gdch_credentials(transport_class):
+    host = 'https://language.com'
+    api_audience_tests = [None, 'https://language2.com']
+    api_audience_expect = [host, 'https://language2.com']
+    for t, e in zip(api_audience_tests, api_audience_expect):
+        with mock.patch.object(google.auth, 'default', autospec=True) as adc:
+            gdch_mock = mock.MagicMock()
+            type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock)
+            adc.return_value = (gdch_mock, None)
+            transport_class(host=host, api_audience=t)
+            gdch_mock.with_gdch_audience.assert_called_once_with(e)
+
+
+@pytest.mark.parametrize(
+    "transport_class,grpc_helpers",
+    [
+        (transports.ModelServiceGrpcTransport, grpc_helpers),
+        (transports.ModelServiceGrpcAsyncIOTransport, grpc_helpers_async)
+    ],
+)
+def test_model_service_transport_create_channel(transport_class, grpc_helpers):
+    # If credentials and host are not provided, the transport class should use
+    # ADC credentials.
+    with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object(
+        grpc_helpers, "create_channel", autospec=True
+    ) as create_channel:
+        creds = ga_credentials.AnonymousCredentials()
+        adc.return_value = (creds, None)
+        transport_class(
+            quota_project_id="octopus",
+            scopes=["1", "2"]
+        )
+
+        create_channel.assert_called_with(
+            "generativelanguage.googleapis.com:443",
+            credentials=creds,
+            credentials_file=None,
+            quota_project_id="octopus",
+            default_scopes=(),
+            scopes=["1", "2"],
+            default_host="generativelanguage.googleapis.com",
+            ssl_credentials=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+
+@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport])
+def test_model_service_grpc_transport_client_cert_source_for_mtls(
+    transport_class
+):
+    cred = ga_credentials.AnonymousCredentials()
+
+    # Check ssl_channel_credentials is used if provided.
+    with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+        mock_ssl_channel_creds = mock.Mock()
+        transport_class(
+            host="squid.clam.whelk",
+            credentials=cred,
+            ssl_channel_credentials=mock_ssl_channel_creds
+        )
+        mock_create_channel.assert_called_once_with(
+            "squid.clam.whelk:443",
+            credentials=cred,
+            credentials_file=None,
+            scopes=None,
+            ssl_credentials=mock_ssl_channel_creds,
+            quota_project_id=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+    # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+    # is used.
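+    # (client_cert_source_for_mtls is a callback returning (cert_bytes, key_bytes),
+    # which the transport passes to grpc.ssl_channel_credentials().)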
+    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+            transport_class(
+                credentials=cred,
+                client_cert_source_for_mtls=client_cert_source_callback
+            )
+            expected_cert, expected_key = client_cert_source_callback()
+            mock_ssl_cred.assert_called_once_with(
+                certificate_chain=expected_cert,
+                private_key=expected_key
+            )
+
+
+def test_model_service_http_transport_client_cert_source_for_mtls():
+    cred = ga_credentials.AnonymousCredentials()
+    with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel:
+        transports.ModelServiceRestTransport(
+            credentials=cred,
+            client_cert_source_for_mtls=client_cert_source_callback
+        )
+        mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)
+
+
+@pytest.mark.parametrize("transport_name", [
+    "grpc",
+    "grpc_asyncio",
+    "rest",
+])
+def test_model_service_host_no_port(transport_name):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'),
+        transport=transport_name,
+    )
+    assert client.transport._host == (
+        'generativelanguage.googleapis.com:443'
+        if transport_name in ['grpc', 'grpc_asyncio']
+        else 'https://generativelanguage.googleapis.com'
+    )
+
+
+@pytest.mark.parametrize("transport_name", [
+    "grpc",
+    "grpc_asyncio",
+    "rest",
+])
+def test_model_service_host_with_port(transport_name):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'),
+        transport=transport_name,
+    )
+    assert client.transport._host == (
+        'generativelanguage.googleapis.com:8000'
+        if transport_name in ['grpc', 'grpc_asyncio']
+        else 'https://generativelanguage.googleapis.com:8000'
+    )
+
+
+@pytest.mark.parametrize("transport_name", [
+    "rest",
+])
+def test_model_service_client_transport_session_collision(transport_name):
+    creds1 = ga_credentials.AnonymousCredentials()
+    creds2 = ga_credentials.AnonymousCredentials()
+    client1 = ModelServiceClient(
+        credentials=creds1,
+        transport=transport_name,
+    )
+    client2 = ModelServiceClient(
+        credentials=creds2,
+        transport=transport_name,
+    )
+    session1 = client1.transport.get_model._session
+    session2 = client2.transport.get_model._session
+    assert session1 != session2
+    session1 = client1.transport.list_models._session
+    session2 = client2.transport.list_models._session
+    assert session1 != session2
+
+
+def test_model_service_grpc_transport_channel():
+    channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.ModelServiceGrpcTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_model_service_grpc_asyncio_transport_channel():
+    channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.ModelServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport])
+def test_model_service_transport_channel_mtls_with_client_cert_source(
+    transport_class
+):
+    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
+        with mock.patch.object(transport_class, "create_channel") as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = ga_credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(google.auth, 'default') as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=None,
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport])
+def test_model_service_transport_channel_mtls_with_adc(
+    transport_class
+):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(transport_class, "create_channel") as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=None,
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_model_path():
+    model = "squid"
+    expected = "models/{model}".format(model=model, )
+    actual = ModelServiceClient.model_path(model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "model": "clam",
+    }
+    path = ModelServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_model_path(path)
+    assert expected == actual
+
+
+def test_common_billing_account_path():
+    billing_account = "whelk"
+    expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+    actual = ModelServiceClient.common_billing_account_path(billing_account)
+    assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+    expected = {
+        "billing_account": "octopus",
+    }
+    path = ModelServiceClient.common_billing_account_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_common_billing_account_path(path)
+    assert expected == actual
+
+
+def test_common_folder_path():
+    folder = "oyster"
+    expected = "folders/{folder}".format(folder=folder, )
+    actual = ModelServiceClient.common_folder_path(folder)
+    assert expected == actual
+
+
+def test_parse_common_folder_path():
+    expected = {
+        "folder": "nudibranch",
+    }
+    path = ModelServiceClient.common_folder_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_common_folder_path(path)
+    assert expected == actual
+
+
+def test_common_organization_path():
+    organization = "cuttlefish"
+    expected = "organizations/{organization}".format(organization=organization, )
+    actual = ModelServiceClient.common_organization_path(organization)
+    assert expected == actual
+
+
+def test_parse_common_organization_path():
+    expected = {
+        "organization": "mussel",
+    }
+    path = ModelServiceClient.common_organization_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_common_organization_path(path)
+    assert expected == actual
+
+
+def test_common_project_path():
+    project = "winkle"
+    expected = "projects/{project}".format(project=project, )
+    actual = ModelServiceClient.common_project_path(project)
+    assert expected == actual
+
+
+def test_parse_common_project_path():
+    expected = {
+        "project": "nautilus",
+    }
+    path = ModelServiceClient.common_project_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_common_project_path(path)
+    assert expected == actual
+
+
+def test_common_location_path():
+    project = "scallop"
+    location = "abalone"
+    expected = "projects/{project}/locations/{location}".format(project=project, location=location, )
+    actual = ModelServiceClient.common_location_path(project, location)
+    assert expected == actual
+
+
+def test_parse_common_location_path():
+    expected = {
+        "project": "squid",
+        "location": "clam",
+    }
+    path = ModelServiceClient.common_location_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_common_location_path(path)
+    assert expected == actual
+
+
+def test_client_with_default_client_info():
+    client_info = gapic_v1.client_info.ClientInfo()
+
+    with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep:
+        client = ModelServiceClient(
+            credentials=ga_credentials.AnonymousCredentials(),
+            client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
+
+    with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep:
+        transport_class = ModelServiceClient.get_transport_class()
+        transport = transport_class(
+            credentials=ga_credentials.AnonymousCredentials(),
+            client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
+
+
+@pytest.mark.asyncio
+async def test_transport_close_async():
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="grpc_asyncio",
+    )
+    with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close:
+        async with client:
+            close.assert_not_called()
+        close.assert_called_once()
+
+
+def test_transport_close():
+    transports = {
+        "rest": "_session",
+        "grpc": "_grpc_channel",
+    }
+
+    for transport, close_name in transports.items():
+        client = ModelServiceClient(
+            credentials=ga_credentials.AnonymousCredentials(),
+            transport=transport
+        )
+        with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close:
+            with client:
+                close.assert_not_called()
+            close.assert_called_once()
+
+
+def test_client_ctx():
+    transports = [
+        'rest',
+        'grpc',
+    ]
+    for transport in transports:
+        client = ModelServiceClient(
+            credentials=ga_credentials.AnonymousCredentials(),
+            transport=transport
+        )
+        # Test client calls underlying transport.
+        with mock.patch.object(type(client.transport), "close") as close:
+            close.assert_not_called()
+            with client:
+                pass
+            close.assert_called()
+
+
+@pytest.mark.parametrize("client_class,transport_class", [
+    (ModelServiceClient, transports.ModelServiceGrpcTransport),
+    (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport),
+])
+def test_api_key_credentials(client_class, transport_class):
+    with mock.patch.object(
+        google.auth._default, "get_api_key_credentials", create=True
+    ) as get_api_key_credentials:
+        mock_cred = mock.Mock()
+        get_api_key_credentials.return_value = mock_cred
+        options = client_options.ClientOptions()
+        options.api_key = "api_key"
+        with mock.patch.object(transport_class, "__init__") as patched:
+            patched.return_value = None
+            client = client_class(client_options=options)
+            patched.assert_called_once_with(
+                credentials=mock_cred,
+                credentials_file=None,
+                host=client.DEFAULT_ENDPOINT,
+                scopes=None,
+                client_cert_source_for_mtls=None,
+                quota_project_id=None,
+                client_info=transports.base.DEFAULT_CLIENT_INFO,
+                always_use_jwt_access=True,
+                api_audience=None,
+            )
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_text_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_text_service.py
new file mode 100644
index 000000000000..2fbd8b3036c2
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_text_service.py
@@ -0,0 +1,2214 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+# try/except added for compatibility with python < 3.8
+try:
+    from unittest import mock
+    from unittest.mock import AsyncMock  # pragma: NO COVER
+except ImportError:  # pragma: NO COVER
+    import mock
+
+import grpc
+from grpc.experimental import aio
+from collections.abc import Iterable
+from google.protobuf import json_format
+import json
+import math
+import pytest
+from proto.marshal.rules.dates import DurationRule, TimestampRule
+from proto.marshal.rules import wrappers
+from requests import Response
+from requests import Request, PreparedRequest
+from requests.sessions import Session
+
+from google.ai.generativelanguage_v1beta2.services.text_service import TextServiceAsyncClient
+from google.ai.generativelanguage_v1beta2.services.text_service import TextServiceClient
+from google.ai.generativelanguage_v1beta2.services.text_service import transports
+from google.ai.generativelanguage_v1beta2.types import safety
+from google.ai.generativelanguage_v1beta2.types import text_service
+from google.api_core import client_options
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers
+from google.api_core import grpc_helpers_async
+from google.api_core import path_template
+from google.auth import credentials as ga_credentials
+from google.auth.exceptions import MutualTLSChannelError
+from google.oauth2 import service_account
+import google.auth
+
+
+def client_cert_source_callback():
+    return b"cert bytes", b"key bytes"
+
+
+# If default endpoint is localhost, then default mtls endpoint will be the same.
+# This method modifies the default endpoint so the client can produce a different
+# mtls endpoint for endpoint testing purposes.
+def modify_default_endpoint(client):
+    return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT
+
+
+def test__get_default_mtls_endpoint():
+    api_endpoint = "example.googleapis.com"
+    api_mtls_endpoint = "example.mtls.googleapis.com"
+    sandbox_endpoint = "example.sandbox.googleapis.com"
+    sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com"
+    non_googleapi = "api.example.com"
+
+    assert TextServiceClient._get_default_mtls_endpoint(None) is None
+    assert TextServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint
+    assert TextServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint
+    assert TextServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint
+    assert TextServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint
+    assert TextServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi
+
+
+@pytest.mark.parametrize("client_class,transport_name", [
+    (TextServiceClient, "grpc"),
+    (TextServiceAsyncClient, "grpc_asyncio"),
+    (TextServiceClient, "rest"),
+])
+def test_text_service_client_from_service_account_info(client_class, transport_name):
+    creds = ga_credentials.AnonymousCredentials()
+    with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory:
+        factory.return_value = creds
+        info = {"valid": True}
+        client = client_class.from_service_account_info(info, transport=transport_name)
+        assert client.transport._credentials == creds
+        assert isinstance(client, client_class)
+
+        assert client.transport._host == (
+            'generativelanguage.googleapis.com:443'
+            if transport_name in ['grpc', 'grpc_asyncio']
+            else 'https://generativelanguage.googleapis.com'
+        )
+
+
+@pytest.mark.parametrize("transport_class,transport_name", [
+    (transports.TextServiceGrpcTransport, "grpc"),
+    (transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"),
+    (transports.TextServiceRestTransport, "rest"),
+])
+def test_text_service_client_service_account_always_use_jwt(transport_class, transport_name):
+    with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt:
+        creds = service_account.Credentials(None, None, None)
+        transport = transport_class(credentials=creds, always_use_jwt_access=True)
+        use_jwt.assert_called_once_with(True)
+
+    with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt:
+        creds = service_account.Credentials(None, None, None)
+        transport = transport_class(credentials=creds, always_use_jwt_access=False)
+        use_jwt.assert_not_called()
+
+
+@pytest.mark.parametrize("client_class,transport_name", [
+    (TextServiceClient, "grpc"),
+    (TextServiceAsyncClient, "grpc_asyncio"),
+    (TextServiceClient, "rest"),
+])
+def test_text_service_client_from_service_account_file(client_class, transport_name):
+    creds = ga_credentials.AnonymousCredentials()
+    with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory:
+        factory.return_value = creds
+        client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name)
+        assert client.transport._credentials == creds
+        assert isinstance(client, client_class)
+
+        client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name)
+        assert client.transport._credentials == creds
+        assert isinstance(client, client_class)
+
+        assert client.transport._host == (
+            'generativelanguage.googleapis.com:443'
+            if transport_name in ['grpc', 'grpc_asyncio']
+            else 'https://generativelanguage.googleapis.com'
+        )
+
+
+def test_text_service_client_get_transport_class():
+    transport = TextServiceClient.get_transport_class()
+    available_transports = [
+        transports.TextServiceGrpcTransport,
+        transports.TextServiceRestTransport,
+    ]
+    assert transport in available_transports
+
+    transport = TextServiceClient.get_transport_class("grpc")
+    assert transport == transports.TextServiceGrpcTransport
+
+
+@pytest.mark.parametrize("client_class,transport_class,transport_name", [
+    (TextServiceClient, transports.TextServiceGrpcTransport, "grpc"),
+    (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"),
+    (TextServiceClient, transports.TextServiceRestTransport, "rest"),
+])
+@mock.patch.object(TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient))
+@mock.patch.object(TextServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceAsyncClient))
+def test_text_service_client_client_options(client_class, transport_class, transport_name):
+    # Check that if channel is provided we won't create a new one.
+    with mock.patch.object(TextServiceClient, 'get_transport_class') as gtc:
+        transport = transport_class(
+            credentials=ga_credentials.AnonymousCredentials()
+        )
+        client = client_class(transport=transport)
+        gtc.assert_not_called()
+
+    # Check that if channel is provided via str we will create a new one.
+    with mock.patch.object(TextServiceClient, 'get_transport_class') as gtc:
+        client = client_class(transport=transport_name)
+        gtc.assert_called()
+
+    # Check the case api_endpoint is provided.
+    options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
+    with mock.patch.object(transport_class, '__init__') as patched:
+        patched.return_value = None
+        client = client_class(transport=transport_name, client_options=options)
+        patched.assert_called_once_with(
+            credentials=None,
+            credentials_file=None,
+            host="squid.clam.whelk",
+            scopes=None,
+            client_cert_source_for_mtls=None,
+            quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
+            always_use_jwt_access=True,
+            api_audience=None,
+        )
+
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
+    # "never".
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+        with mock.patch.object(transport_class, '__init__') as patched:
+            patched.return_value = None
+            client = client_class(transport=transport_name)
+            patched.assert_called_once_with(
+                credentials=None,
+                credentials_file=None,
+                host=client.DEFAULT_ENDPOINT,
+                scopes=None,
+                client_cert_source_for_mtls=None,
+                quota_project_id=None,
+                client_info=transports.base.DEFAULT_CLIENT_INFO,
+                always_use_jwt_access=True,
+                api_audience=None,
+            )
+
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
+    # "always".
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
+        with mock.patch.object(transport_class, '__init__') as patched:
+            patched.return_value = None
+            client = client_class(transport=transport_name)
+            patched.assert_called_once_with(
+                credentials=None,
+                credentials_file=None,
+                host=client.DEFAULT_MTLS_ENDPOINT,
+                scopes=None,
+                client_cert_source_for_mtls=None,
+                quota_project_id=None,
+                client_info=transports.base.DEFAULT_CLIENT_INFO,
+                always_use_jwt_access=True,
+                api_audience=None,
+            )
+
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
+    # unsupported value.
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+        with pytest.raises(MutualTLSChannelError):
+            client = client_class(transport=transport_name)
+
+    # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}):
+        with pytest.raises(ValueError):
+            client = client_class(transport=transport_name)
+
+    # Check the case quota_project_id is provided
+    options = client_options.ClientOptions(quota_project_id="octopus")
+    with mock.patch.object(transport_class, '__init__') as patched:
+        patched.return_value = None
+        client = client_class(client_options=options, transport=transport_name)
+        patched.assert_called_once_with(
+            credentials=None,
+            credentials_file=None,
+            host=client.DEFAULT_ENDPOINT,
+            scopes=None,
+            client_cert_source_for_mtls=None,
+            quota_project_id="octopus",
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
+            always_use_jwt_access=True,
+            api_audience=None,
+        )
+    # Check the case api_audience is provided
+    options = client_options.ClientOptions(api_audience="https://language.googleapis.com")
+    with mock.patch.object(transport_class, '__init__') as patched:
+        patched.return_value = None
+        client = client_class(client_options=options, transport=transport_name)
+        patched.assert_called_once_with(
+            credentials=None,
+            credentials_file=None,
+            host=client.DEFAULT_ENDPOINT,
+            scopes=None,
+            client_cert_source_for_mtls=None,
+            quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
+            always_use_jwt_access=True,
+            api_audience="https://language.googleapis.com"
+        )
+
+
+@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [
+    (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", "true"),
+    (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"),
+    (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", "false"),
+    (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"),
+    (TextServiceClient, transports.TextServiceRestTransport, "rest", "true"),
+    (TextServiceClient, transports.TextServiceRestTransport, "rest", "false"),
+])
+@mock.patch.object(TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient))
+@mock.patch.object(TextServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceAsyncClient))
+@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
+def test_text_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env):
+    # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default
+    # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.
+
+    # Check the case client_cert_source is provided. Whether client cert is used depends on
Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class", [ + TextServiceClient, TextServiceAsyncClient +]) +@mock.patch.object(TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient)) +@mock.patch.object(TextServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceAsyncClient)) +def test_text_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc"), + (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (TextServiceClient, transports.TextServiceRestTransport, "rest"), +]) +def test_text_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. 
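+    # Callers request non-default OAuth scopes the same way; a minimal sketch
+    # (hypothetical scope URL, not part of this test):
+    #
+    #   options = client_options.ClientOptions(
+    #       scopes=["https://www.googleapis.com/auth/cloud-platform"],
+    #   )
+    #   client = TextServiceClient(client_options=options)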
+ options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", grpc_helpers), + (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), + (TextServiceClient, transports.TextServiceRestTransport, "rest", None), +]) +def test_text_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +def test_text_service_client_client_options_from_dict(): + with mock.patch('google.ai.generativelanguage_v1beta2.services.text_service.transports.TextServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = TextServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", grpc_helpers), + (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), +]) +def test_text_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. 
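+    # (load_credentials_from_file is patched below, so no real file named
+    # "credentials.json" has to exist.) A caller would opt in the same way,
+    # assuming a valid service account key at that path (sketch):
+    #
+    #   client = TextServiceClient(
+    #       client_options={"credentials_file": "credentials.json"},
+    #   )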
+    with mock.patch.object(
+        google.auth, "load_credentials_from_file", autospec=True
+    ) as load_creds, mock.patch.object(
+        google.auth, "default", autospec=True
+    ) as adc, mock.patch.object(
+        grpc_helpers, "create_channel"
+    ) as create_channel:
+        creds = ga_credentials.AnonymousCredentials()
+        file_creds = ga_credentials.AnonymousCredentials()
+        load_creds.return_value = (file_creds, None)
+        adc.return_value = (creds, None)
+        client = client_class(client_options=options, transport=transport_name)
+        create_channel.assert_called_with(
+            "generativelanguage.googleapis.com:443",
+            credentials=file_creds,
+            credentials_file=None,
+            quota_project_id=None,
+            default_scopes=(),
+            scopes=None,
+            default_host="generativelanguage.googleapis.com",
+            ssl_credentials=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+
+@pytest.mark.parametrize("request_type", [
+  text_service.GenerateTextRequest,
+  dict,
+])
+def test_generate_text(request_type, transport: str = 'grpc'):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.generate_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = text_service.GenerateTextResponse(
+        )
+        response = client.generate_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.GenerateTextRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, text_service.GenerateTextResponse)
+
+
+def test_generate_text_empty_call():
+    # This test is a coverage failsafe to make sure that totally empty calls,
+    # i.e. request == None and no flattened fields passed, work.
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.generate_text),
+            '__call__') as call:
+        client.generate_text()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.GenerateTextRequest()
+
+@pytest.mark.asyncio
+async def test_generate_text_async(transport: str = 'grpc_asyncio', request_type=text_service.GenerateTextRequest):
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.generate_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            text_service.GenerateTextResponse()
+        )
+        response = await client.generate_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == text_service.GenerateTextRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.GenerateTextResponse) + + +@pytest.mark.asyncio +async def test_generate_text_async_from_dict(): + await test_generate_text_async(request_type=dict) + + +def test_generate_text_field_headers(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.GenerateTextRequest() + + request.model = 'model_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_text), + '__call__') as call: + call.return_value = text_service.GenerateTextResponse() + client.generate_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model=model_value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_generate_text_field_headers_async(): + client = TextServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.GenerateTextRequest() + + request.model = 'model_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_text), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.GenerateTextResponse()) + await client.generate_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model=model_value', + ) in kw['metadata'] + + +def test_generate_text_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_text), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.GenerateTextResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.generate_text( + model='model_value', + prompt=text_service.TextPrompt(text='text_value'), + temperature=0.1198, + candidate_count=1573, + max_output_tokens=1865, + top_p=0.546, + top_k=541, + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].model
+        mock_val = 'model_value'
+        assert arg == mock_val
+        arg = args[0].prompt
+        mock_val = text_service.TextPrompt(text='text_value')
+        assert arg == mock_val
+        assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6)
+        arg = args[0].candidate_count
+        mock_val = 1573
+        assert arg == mock_val
+        arg = args[0].max_output_tokens
+        mock_val = 1865
+        assert arg == mock_val
+        assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6)
+        arg = args[0].top_k
+        mock_val = 541
+        assert arg == mock_val
+
+
+def test_generate_text_flattened_error():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.generate_text(
+            text_service.GenerateTextRequest(),
+            model='model_value',
+            prompt=text_service.TextPrompt(text='text_value'),
+            temperature=0.1198,
+            candidate_count=1573,
+            max_output_tokens=1865,
+            top_p=0.546,
+            top_k=541,
+        )
+
+@pytest.mark.asyncio
+async def test_generate_text_flattened_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.generate_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.GenerateTextResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.generate_text(
+            model='model_value',
+            prompt=text_service.TextPrompt(text='text_value'),
+            temperature=0.1198,
+            candidate_count=1573,
+            max_output_tokens=1865,
+            top_p=0.546,
+            top_k=541,
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].model
+        mock_val = 'model_value'
+        assert arg == mock_val
+        arg = args[0].prompt
+        mock_val = text_service.TextPrompt(text='text_value')
+        assert arg == mock_val
+        assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6)
+        arg = args[0].candidate_count
+        mock_val = 1573
+        assert arg == mock_val
+        arg = args[0].max_output_tokens
+        mock_val = 1865
+        assert arg == mock_val
+        assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6)
+        arg = args[0].top_k
+        mock_val = 541
+        assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_generate_text_flattened_error_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.generate_text(
+            text_service.GenerateTextRequest(),
+            model='model_value',
+            prompt=text_service.TextPrompt(text='text_value'),
+            temperature=0.1198,
+            candidate_count=1573,
+            max_output_tokens=1865,
+            top_p=0.546,
+            top_k=541,
+        )
+
+
+@pytest.mark.parametrize("request_type", [
+  text_service.EmbedTextRequest,
+  dict,
+])
+def test_embed_text(request_type, transport: str = 'grpc'):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = text_service.EmbedTextResponse(
+        )
+        response = client.embed_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.EmbedTextRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, text_service.EmbedTextResponse)
+
+
+def test_embed_text_empty_call():
+    # This test is a coverage failsafe to make sure that totally empty calls,
+    # i.e. request == None and no flattened fields passed, work.
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        client.embed_text()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.EmbedTextRequest()
+
+@pytest.mark.asyncio
+async def test_embed_text_async(transport: str = 'grpc_asyncio', request_type=text_service.EmbedTextRequest):
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            text_service.EmbedTextResponse()
+        )
+        response = await client.embed_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.EmbedTextRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, text_service.EmbedTextResponse)
+
+
+@pytest.mark.asyncio
+async def test_embed_text_async_from_dict():
+    await test_embed_text_async(request_type=dict)
+
+
+def test_embed_text_field_headers():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = text_service.EmbedTextRequest()
+
+    request.model = 'model_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
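+    # The client derives the routing header from the request's resource name,
+    # so on the wire the call carries metadata like (illustrative values):
+    #
+    #   ('x-goog-request-params', 'model=models%2Fsome-model')
+    #
+    # with the value URL-encoded; the 'model_value' used here has no reserved
+    # characters, which is why the assertion below matches it verbatim.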
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        call.return_value = text_service.EmbedTextResponse()
+        client.embed_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'model=model_value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_embed_text_field_headers_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = text_service.EmbedTextRequest()
+
+    request.model = 'model_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.EmbedTextResponse())
+        await client.embed_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'model=model_value',
+    ) in kw['metadata']
+
+
+def test_embed_text_flattened():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = text_service.EmbedTextResponse()
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.embed_text(
+            model='model_value',
+            text='text_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].model
+        mock_val = 'model_value'
+        assert arg == mock_val
+        arg = args[0].text
+        mock_val = 'text_value'
+        assert arg == mock_val
+
+
+def test_embed_text_flattened_error():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.embed_text(
+            text_service.EmbedTextRequest(),
+            model='model_value',
+            text='text_value',
+        )
+
+@pytest.mark.asyncio
+async def test_embed_text_flattened_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.EmbedTextResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.embed_text(
+            model='model_value',
+            text='text_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = 'model_value' + assert arg == mock_val + arg = args[0].text + mock_val = 'text_value' + assert arg == mock_val + +@pytest.mark.asyncio +async def test_embed_text_flattened_error_async(): + client = TextServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.embed_text( + text_service.EmbedTextRequest(), + model='model_value', + text='text_value', + ) + + +@pytest.mark.parametrize("request_type", [ + text_service.GenerateTextRequest, + dict, +]) +def test_generate_text_rest(request_type): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'model': 'models/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = text_service.GenerateTextResponse( + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = text_service.GenerateTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.generate_text(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.GenerateTextResponse) + + +def test_generate_text_rest_required_fields(request_type=text_service.GenerateTextRequest): + transport_class = transports.TextServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = 'model_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == 'model_value' + + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = text_service.GenerateTextResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. 
+        with mock.patch.object(path_template, 'transcode') as transcode:
+            # A uri without fields and an empty body will force all the
+            # request fields to show up in the query_params.
+            pb_request = request_type.pb(request)
+            transcode_result = {
+                'uri': 'v1/sample_method',
+                'method': "post",
+                'query_params': pb_request,
+            }
+            transcode_result['body'] = pb_request
+            transcode.return_value = transcode_result
+
+            response_value = Response()
+            response_value.status_code = 200
+
+            pb_return_value = text_service.GenerateTextResponse.pb(return_value)
+            json_return_value = json_format.MessageToJson(pb_return_value)
+
+            response_value._content = json_return_value.encode('UTF-8')
+            req.return_value = response_value
+
+            response = client.generate_text(request)
+
+            expected_params = [
+                ('$alt', 'json;enum-encoding=int')
+            ]
+            actual_params = req.call_args.kwargs['params']
+            assert expected_params == actual_params
+
+
+def test_generate_text_rest_unset_required_fields():
+    transport = transports.TextServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+    unset_fields = transport.generate_text._get_unset_required_fields({})
+    assert set(unset_fields) == (set(()) & set(("model", "prompt", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_generate_text_rest_interceptors(null_interceptor):
+    transport = transports.TextServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.TextServiceRestInterceptor(),
+        )
+    client = TextServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+        mock.patch.object(path_template, "transcode") as transcode, \
+        mock.patch.object(transports.TextServiceRestInterceptor, "post_generate_text") as post, \
+        mock.patch.object(transports.TextServiceRestInterceptor, "pre_generate_text") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = text_service.GenerateTextRequest.pb(text_service.GenerateTextRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = text_service.GenerateTextResponse.to_json(text_service.GenerateTextResponse())
+
+        request = text_service.GenerateTextRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = text_service.GenerateTextResponse()
+
+        client.generate_text(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_generate_text_rest_bad_request(transport: str = 'rest', request_type=text_service.GenerateTextRequest):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'model': 'models/sample1'}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
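+    # google-api-core maps HTTP status codes onto exception classes (400 ->
+    # BadRequest, 404 -> NotFound, ...), so faking a 400 below is enough to
+    # exercise the error path. A caller would typically guard it (sketch):
+    #
+    #   try:
+    #       client.generate_text(request)
+    #   except core_exceptions.BadRequest as e:
+    #       ...  # inspect e.message before retrying or surfacing the error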
+ with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.generate_text(request) + + +def test_generate_text_rest_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = text_service.GenerateTextResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {'model': 'models/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + model='model_value', + prompt=text_service.TextPrompt(text='text_value'), + temperature=0.1198, + candidate_count=1573, + max_output_tokens=1865, + top_p=0.546, + top_k=541, + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = text_service.GenerateTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.generate_text(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta2/{model=models/*}:generateText" % client.transport._host, args[1]) + + +def test_generate_text_rest_flattened_error(transport: str = 'rest'): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.generate_text( + text_service.GenerateTextRequest(), + model='model_value', + prompt=text_service.TextPrompt(text='text_value'), + temperature=0.1198, + candidate_count=1573, + max_output_tokens=1865, + top_p=0.546, + top_k=541, + ) + + +def test_generate_text_rest_error(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + text_service.EmbedTextRequest, + dict, +]) +def test_embed_text_rest(request_type): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'model': 'models/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. 
+ return_value = text_service.EmbedTextResponse( + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = text_service.EmbedTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.embed_text(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.EmbedTextResponse) + + +def test_embed_text_rest_required_fields(request_type=text_service.EmbedTextRequest): + transport_class = transports.TextServiceRestTransport + + request_init = {} + request_init["model"] = "" + request_init["text"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).embed_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = 'model_value' + jsonified_request["text"] = 'text_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).embed_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == 'model_value' + assert "text" in jsonified_request + assert jsonified_request["text"] == 'text_value' + + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = text_service.EmbedTextResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+            pb_request = request_type.pb(request)
+            transcode_result = {
+                'uri': 'v1/sample_method',
+                'method': "post",
+                'query_params': pb_request,
+            }
+            transcode_result['body'] = pb_request
+            transcode.return_value = transcode_result
+
+            response_value = Response()
+            response_value.status_code = 200
+
+            pb_return_value = text_service.EmbedTextResponse.pb(return_value)
+            json_return_value = json_format.MessageToJson(pb_return_value)
+
+            response_value._content = json_return_value.encode('UTF-8')
+            req.return_value = response_value
+
+            response = client.embed_text(request)
+
+            expected_params = [
+                ('$alt', 'json;enum-encoding=int')
+            ]
+            actual_params = req.call_args.kwargs['params']
+            assert expected_params == actual_params
+
+
+def test_embed_text_rest_unset_required_fields():
+    transport = transports.TextServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+    unset_fields = transport.embed_text._get_unset_required_fields({})
+    assert set(unset_fields) == (set(()) & set(("model", "text", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_embed_text_rest_interceptors(null_interceptor):
+    transport = transports.TextServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.TextServiceRestInterceptor(),
+        )
+    client = TextServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+        mock.patch.object(path_template, "transcode") as transcode, \
+        mock.patch.object(transports.TextServiceRestInterceptor, "post_embed_text") as post, \
+        mock.patch.object(transports.TextServiceRestInterceptor, "pre_embed_text") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = text_service.EmbedTextRequest.pb(text_service.EmbedTextRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = text_service.EmbedTextResponse.to_json(text_service.EmbedTextResponse())
+
+        request = text_service.EmbedTextRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = text_service.EmbedTextResponse()
+
+        client.embed_text(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_embed_text_rest_bad_request(transport: str = 'rest', request_type=text_service.EmbedTextRequest):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'model': 'models/sample1'}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
+    with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 400
+        response_value.request = Request()
+        req.return_value = response_value
+        client.embed_text(request)
+
+
+def test_embed_text_rest_flattened():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # Mock the http request call within the method and fake a response.
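+    # Flattened keyword arguments are a convenience over the request object;
+    # the two invocations below are equivalent on the wire (sketch, with
+    # illustrative values):
+    #
+    #   client.embed_text(model='models/sample1', text='hello')
+    #   client.embed_text(text_service.EmbedTextRequest(
+    #       model='models/sample1', text='hello'))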
+ with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = text_service.EmbedTextResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {'model': 'models/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + model='model_value', + text='text_value', + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = text_service.EmbedTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.embed_text(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta2/{model=models/*}:embedText" % client.transport._host, args[1]) + + +def test_embed_text_rest_flattened_error(transport: str = 'rest'): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.embed_text( + text_service.EmbedTextRequest(), + model='model_value', + text='text_value', + ) + + +def test_embed_text_rest_error(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TextServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = TextServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = TextServiceClient( + client_options=options, + credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TextServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. 
+ transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = TextServiceClient(transport=transport) + assert client.transport is transport + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.TextServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + +@pytest.mark.parametrize("transport_class", [ + transports.TextServiceGrpcTransport, + transports.TextServiceGrpcAsyncIOTransport, + transports.TextServiceRestTransport, +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, 'default') as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + +@pytest.mark.parametrize("transport_name", [ + "grpc", + "rest", +]) +def test_transport_kind(transport_name): + transport = TextServiceClient.get_transport_class(transport_name)( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert transport.kind == transport_name + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.TextServiceGrpcTransport, + ) + +def test_text_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.TextServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_text_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.ai.generativelanguage_v1beta2.services.text_service.transports.TextServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.TextServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
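+    # i.e. every RPC stub on the abstract base is effectively (sketch):
+    #
+    #   def generate_text(self, request):
+    #       raise NotImplementedError()
+    #
+    # and concrete transports (gRPC, REST) override them with real calls.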
+    methods = (
+        'generate_text',
+        'embed_text',
+    )
+    for method in methods:
+        with pytest.raises(NotImplementedError):
+            getattr(transport, method)(request=object())
+
+    with pytest.raises(NotImplementedError):
+        transport.close()
+
+    # Catch all for all remaining methods and properties
+    remainder = [
+        'kind',
+    ]
+    for r in remainder:
+        with pytest.raises(NotImplementedError):
+            getattr(transport, r)()
+
+
+def test_text_service_base_transport_with_credentials_file():
+    # Instantiate the base transport with a credentials file
+    with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta2.services.text_service.transports.TextServiceTransport._prep_wrapped_messages') as Transport:
+        Transport.return_value = None
+        load_creds.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport = transports.TextServiceTransport(
+            credentials_file="credentials.json",
+            quota_project_id="octopus",
+        )
+        load_creds.assert_called_once_with("credentials.json",
+            scopes=None,
+            default_scopes=(),
+            quota_project_id="octopus",
+        )
+
+
+def test_text_service_base_transport_with_adc():
+    # Test the default credentials are used if credentials and credentials_file are None.
+    with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta2.services.text_service.transports.TextServiceTransport._prep_wrapped_messages') as Transport:
+        Transport.return_value = None
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport = transports.TextServiceTransport()
+        adc.assert_called_once()
+
+
+def test_text_service_auth_adc():
+    # If no credentials are provided, we should use ADC credentials.
+    with mock.patch.object(google.auth, 'default', autospec=True) as adc:
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        TextServiceClient()
+        adc.assert_called_once_with(
+            scopes=None,
+            default_scopes=(),
+            quota_project_id=None,
+        )
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.TextServiceGrpcTransport,
+        transports.TextServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_text_service_transport_auth_adc(transport_class):
+    # If credentials and host are not provided, the transport class should use
+    # ADC credentials.
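+    # "ADC" is Application Default Credentials: google.auth.default() checks
+    # GOOGLE_APPLICATION_CREDENTIALS, then gcloud user credentials, then the
+    # GCE/GKE metadata server. Equivalent caller-side sketch:
+    #
+    #   creds, _ = google.auth.default(scopes=["1", "2"])
+    #   transport = transport_class(credentials=creds)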
+    with mock.patch.object(google.auth, 'default', autospec=True) as adc:
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport_class(quota_project_id="octopus", scopes=["1", "2"])
+        adc.assert_called_once_with(
+            scopes=["1", "2"],
+            default_scopes=(),
+            quota_project_id="octopus",
+        )
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.TextServiceGrpcTransport,
+        transports.TextServiceGrpcAsyncIOTransport,
+        transports.TextServiceRestTransport,
+    ],
+)
+def test_text_service_transport_auth_gdch_credentials(transport_class):
+    host = 'https://language.com'
+    api_audience_tests = [None, 'https://language2.com']
+    api_audience_expect = [host, 'https://language2.com']
+    for t, e in zip(api_audience_tests, api_audience_expect):
+        with mock.patch.object(google.auth, 'default', autospec=True) as adc:
+            gdch_mock = mock.MagicMock()
+            type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock)
+            adc.return_value = (gdch_mock, None)
+            transport_class(host=host, api_audience=t)
+            gdch_mock.with_gdch_audience.assert_called_once_with(
+                e
+            )
+
+
+@pytest.mark.parametrize(
+    "transport_class,grpc_helpers",
+    [
+        (transports.TextServiceGrpcTransport, grpc_helpers),
+        (transports.TextServiceGrpcAsyncIOTransport, grpc_helpers_async)
+    ],
+)
+def test_text_service_transport_create_channel(transport_class, grpc_helpers):
+    # If credentials and host are not provided, the transport class should use
+    # ADC credentials.
+    with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object(
+        grpc_helpers, "create_channel", autospec=True
+    ) as create_channel:
+        creds = ga_credentials.AnonymousCredentials()
+        adc.return_value = (creds, None)
+        transport_class(
+            quota_project_id="octopus",
+            scopes=["1", "2"]
+        )
+
+        create_channel.assert_called_with(
+            "generativelanguage.googleapis.com:443",
+            credentials=creds,
+            credentials_file=None,
+            quota_project_id="octopus",
+            default_scopes=(),
+            scopes=["1", "2"],
+            default_host="generativelanguage.googleapis.com",
+            ssl_credentials=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+
+@pytest.mark.parametrize("transport_class", [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport])
+def test_text_service_grpc_transport_client_cert_source_for_mtls(
+    transport_class
+):
+    cred = ga_credentials.AnonymousCredentials()
+
+    # Check ssl_channel_credentials is used if provided.
+    with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+        mock_ssl_channel_creds = mock.Mock()
+        transport_class(
+            host="squid.clam.whelk",
+            credentials=cred,
+            ssl_channel_credentials=mock_ssl_channel_creds
+        )
+        mock_create_channel.assert_called_once_with(
+            "squid.clam.whelk:443",
+            credentials=cred,
+            credentials_file=None,
+            scopes=None,
+            ssl_credentials=mock_ssl_channel_creds,
+            quota_project_id=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+    # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+    # is used.
+    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+            transport_class(
+                credentials=cred,
+                client_cert_source_for_mtls=client_cert_source_callback
+            )
+            expected_cert, expected_key = client_cert_source_callback()
+            mock_ssl_cred.assert_called_once_with(
+                certificate_chain=expected_cert,
+                private_key=expected_key
+            )
+
+def test_text_service_http_transport_client_cert_source_for_mtls():
+    cred = ga_credentials.AnonymousCredentials()
+    with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel:
+        transports.TextServiceRestTransport(
+            credentials=cred,
+            client_cert_source_for_mtls=client_cert_source_callback
+        )
+        mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)
+
+
+@pytest.mark.parametrize("transport_name", [
+    "grpc",
+    "grpc_asyncio",
+    "rest",
+])
+def test_text_service_host_no_port(transport_name):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'),
+        transport=transport_name,
+    )
+    assert client.transport._host == (
+        'generativelanguage.googleapis.com:443'
+        if transport_name in ['grpc', 'grpc_asyncio']
+        else 'https://generativelanguage.googleapis.com'
+    )
+
+@pytest.mark.parametrize("transport_name", [
+    "grpc",
+    "grpc_asyncio",
+    "rest",
+])
+def test_text_service_host_with_port(transport_name):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'),
+        transport=transport_name,
+    )
+    assert client.transport._host == (
+        'generativelanguage.googleapis.com:8000'
+        if transport_name in ['grpc', 'grpc_asyncio']
+        else 'https://generativelanguage.googleapis.com:8000'
+    )
+
+@pytest.mark.parametrize("transport_name", [
+    "rest",
+])
+def test_text_service_client_transport_session_collision(transport_name):
+    creds1 = ga_credentials.AnonymousCredentials()
+    creds2 = ga_credentials.AnonymousCredentials()
+    client1 = TextServiceClient(
+        credentials=creds1,
+        transport=transport_name,
+    )
+    client2 = TextServiceClient(
+        credentials=creds2,
+        transport=transport_name,
+    )
+    session1 = client1.transport.generate_text._session
+    session2 = client2.transport.generate_text._session
+    assert session1 != session2
+    session1 = client1.transport.embed_text._session
+    session2 = client2.transport.embed_text._session
+    assert session1 != session2
+
+
+def test_text_service_grpc_transport_channel():
+    channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.TextServiceGrpcTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_text_service_grpc_asyncio_transport_channel():
+    channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
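+    # Supplying a pre-built channel bypasses credential and endpoint handling
+    # entirely; callers needing custom channel options can do the same
+    # (sketch, hypothetical endpoint):
+    #
+    #   channel = aio.secure_channel('myhost:443', grpc.ssl_channel_credentials())
+    #   transport = transports.TextServiceGrpcAsyncIOTransport(channel=channel)
+    #   client = TextServiceAsyncClient(transport=transport)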
+    transport = transports.TextServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize("transport_class", [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport])
+def test_text_service_transport_channel_mtls_with_client_cert_source(
+    transport_class
+):
+    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
+        with mock.patch.object(transport_class, "create_channel") as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = ga_credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(google.auth, 'default') as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=None,
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize("transport_class", [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport])
+def test_text_service_transport_channel_mtls_with_adc(
+    transport_class
+):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(transport_class, "create_channel") as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=None,
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_model_path():
+    model = "squid"
+    expected = "models/{model}".format(model=model, )
+    actual = TextServiceClient.model_path(model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "model": "clam",
+    }
+    path = TextServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
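+    # e.g. model_path("text-bison-001") returns "models/text-bison-001", and
+    # parse_model_path("models/text-bison-001") returns
+    # {"model": "text-bison-001"} (illustrative model id).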
+ actual = TextServiceClient.parse_model_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "whelk" + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = TextServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "octopus", + } + path = TextServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "oyster" + expected = "folders/{folder}".format(folder=folder, ) + actual = TextServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nudibranch", + } + path = TextServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "cuttlefish" + expected = "organizations/{organization}".format(organization=organization, ) + actual = TextServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "mussel", + } + path = TextServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "winkle" + expected = "projects/{project}".format(project=project, ) + actual = TextServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "nautilus", + } + path = TextServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "scallop" + location = "abalone" + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = TextServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "squid", + "location": "clam", + } + path = TextServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
+ actual = TextServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.TextServiceTransport, '_prep_wrapped_messages') as prep: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.TextServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = TextServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = TextServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "rest": "_session", + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport + ) + with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close: + with client: + close.assert_not_called() + close.assert_called_once() + +def test_client_ctx(): + transports = [ + 'rest', + 'grpc', + ] + for transport in transports: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport + ) + # Test client calls underlying transport. 
+ with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + +@pytest.mark.parametrize("client_class,transport_class", [ + (TextServiceClient, transports.TextServiceGrpcTransport), + (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport), +]) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/.coveragerc b/owl-bot-staging/google-ai-generativelanguage/v1beta3/.coveragerc new file mode 100644 index 000000000000..fd060ae956b5 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/.coveragerc @@ -0,0 +1,13 @@ +[run] +branch = True + +[report] +show_missing = True +omit = + google/ai/generativelanguage/__init__.py + google/ai/generativelanguage/gapic_version.py +exclude_lines = + # Re-enable the standard pragma + pragma: NO COVER + # Ignore debug-only repr + def __repr__ diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/.flake8 b/owl-bot-staging/google-ai-generativelanguage/v1beta3/.flake8 new file mode 100644 index 000000000000..29227d4cf419 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/.flake8 @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by synthtool. DO NOT EDIT! +[flake8] +ignore = E203, E266, E501, W503 +exclude = + # Exclude generated code. + **/proto/** + **/gapic/** + **/services/** + **/types/** + *_pb2.py + + # Standard linting exemptions. 
+  **/.nox/**
+  __pycache__,
+  .git,
+  *.pyc,
+  conf.py
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/MANIFEST.in b/owl-bot-staging/google-ai-generativelanguage/v1beta3/MANIFEST.in
new file mode 100644
index 000000000000..a41cec0defac
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/MANIFEST.in
@@ -0,0 +1,2 @@
+recursive-include google/ai/generativelanguage *.py
+recursive-include google/ai/generativelanguage_v1beta3 *.py
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/README.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta3/README.rst
new file mode 100644
index 000000000000..099f73894711
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/README.rst
@@ -0,0 +1,49 @@
+Python Client for Google Ai Generativelanguage API
+==================================================
+
+Quick Start
+-----------
+
+In order to use this library, you first need to go through the following steps:
+
+1. `Select or create a Cloud Platform project.`_
+2. `Enable billing for your project.`_
+3. Enable the Google Ai Generativelanguage API.
+4. `Setup Authentication.`_
+
+.. _Select or create a Cloud Platform project.: https://console.cloud.google.com/project
+.. _Enable billing for your project.: https://cloud.google.com/billing/docs/how-to/modify-project#enable_billing_for_a_project
+.. _Setup Authentication.: https://googleapis.dev/python/google-api-core/latest/auth.html
+
+Installation
+~~~~~~~~~~~~
+
+Install this library in a `virtualenv`_ using pip. `virtualenv`_ is a tool to
+create isolated Python environments. The basic problem it addresses is one of
+dependencies and versions, and indirectly permissions.
+
+With `virtualenv`_, it's possible to install this library without needing system
+install permissions, and without clashing with the installed system
+dependencies.
+
+.. _`virtualenv`: https://virtualenv.pypa.io/en/latest/
+
+
+Mac/Linux
+^^^^^^^^^
+
+.. code-block:: console
+
+    python3 -m venv <your-env>
+    source <your-env>/bin/activate
+    <your-env>/bin/pip install /path/to/library
+
+
+Windows
+^^^^^^^
+
+.. code-block:: console
+
+    python3 -m venv <your-env>
+    <your-env>\Scripts\activate
+    <your-env>\Scripts\pip.exe install \path\to\library
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/_static/custom.css b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/_static/custom.css
new file mode 100644
index 000000000000..06423be0b592
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/_static/custom.css
@@ -0,0 +1,3 @@
+dl.field-list > dt {
+    min-width: 100px
+}
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/conf.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/conf.py
new file mode 100644
index 000000000000..0f3f4903ff54
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/conf.py
@@ -0,0 +1,376 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+#
+# google-ai-generativelanguage documentation build configuration file
+#
+# This file is execfile()d with the current directory set to its
+# containing dir.
+#
+# Note that not all possible configuration values are present in this
+# autogenerated file.
+#
+# All configuration values have a default; values that are commented out
+# serve to show the default.
+
+import sys
+import os
+import shlex
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+sys.path.insert(0, os.path.abspath(".."))
+
+__version__ = "0.1.0"
+
+# -- General configuration ------------------------------------------------
+
+# If your documentation needs a minimal Sphinx version, state it here.
+needs_sphinx = "4.0.1"
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    "sphinx.ext.autodoc",
+    "sphinx.ext.autosummary",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.coverage",
+    "sphinx.ext.napoleon",
+    "sphinx.ext.todo",
+    "sphinx.ext.viewcode",
+]
+
+# autodoc/autosummary flags
+autoclass_content = "both"
+autodoc_default_flags = ["members"]
+autosummary_generate = True
+
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# Allow markdown includes (so releases.md can include CHANGELOG.md)
+# http://www.sphinx-doc.org/en/master/markdown.html
+source_parsers = {".md": "recommonmark.parser.CommonMarkParser"}
+
+# The suffix(es) of source filenames.
+# You can specify multiple suffixes as a list of strings:
+source_suffix = [".rst", ".md"]
+
+# The encoding of source files.
+# source_encoding = 'utf-8-sig'
+
+# The root toctree document.
+root_doc = "index"
+
+# General information about the project.
+project = u"google-ai-generativelanguage"
+copyright = u"2023, Google, LLC"
+author = u"Google APIs"  # TODO: autogenerate this bit
+
+# The version info for the project you're documenting, acts as replacement for
+# |version| and |release|, also used in various other places throughout the
+# built documents.
+#
+# The full version, including alpha/beta/rc tags.
+release = __version__
+# The short X.Y version.
+version = ".".join(release.split(".")[0:2])
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = 'en'
+
+# There are two options for replacing |today|: either, you set today to some
+# non-false value, then it is used:
+# today = ''
+# Else, today_fmt is used as the format for a strftime call.
+# today_fmt = '%B %d, %Y'
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+exclude_patterns = ["_build"]
+
+# The reST default role (used for this markup: `text`) to use for all
+# documents.
+# default_role = None
+
+# If true, '()' will be appended to :func: etc. cross-reference text.
+# add_function_parentheses = True
+
+# If true, the current module name will be prepended to all description
+# unit titles (such as .. function::).
+# add_module_names = True
+
+# If true, sectionauthor and moduleauthor directives will be shown in the
+# output.
They are ignored by default. +# show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + +# A list of ignored prefixes for module index sorting. +# modindex_common_prefix = [] + +# If true, keep warnings as "system message" paragraphs in the built documents. +# keep_warnings = False + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = True + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = "alabaster" + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +html_theme_options = { + "description": "Google Ai Client Libraries for Python", + "github_user": "googleapis", + "github_repo": "google-cloud-python", + "github_banner": True, + "font_family": "'Roboto', Georgia, sans", + "head_font_family": "'Roboto', Georgia, serif", + "code_font_family": "'Roboto Mono', 'Consolas', monospace", +} + +# Add any paths that contain custom themes here, relative to this directory. +# html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +# html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +# html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +# html_logo = None + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +# html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + +# Add any extra paths that contain custom files (such as robots.txt or +# .htaccess) here, relative to this directory. These files are copied +# directly to the root of the documentation. +# html_extra_path = [] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +# html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +# html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +# html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +# html_additional_pages = {} + +# If false, no module index is generated. +# html_domain_indices = True + +# If false, no index is generated. +# html_use_index = True + +# If true, the index is split into individual pages for each letter. +# html_split_index = False + +# If true, links to the reST sources are added to the pages. +# html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +# html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +# html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. 
The value of this option must be the +# base URL from which the finished HTML is served. +# html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. ".xhtml"). +# html_file_suffix = None + +# Language to be used for generating the HTML full-text search index. +# Sphinx supports the following languages: +# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' +# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' +# html_search_language = 'en' + +# A dictionary with options for the search language support, empty by default. +# Now only 'ja' uses this config value +# html_search_options = {'type': 'default'} + +# The name of a javascript file (relative to the configuration directory) that +# implements a search results scorer. If empty, the default will be used. +# html_search_scorer = 'scorer.js' + +# Output file base name for HTML help builder. +htmlhelp_basename = "google-ai-generativelanguage-doc" + +# -- Options for warnings ------------------------------------------------------ + + +suppress_warnings = [ + # Temporarily suppress this to avoid "more than one target found for + # cross-reference" warning, which are intractable for us to avoid while in + # a mono-repo. + # See https://github.com/sphinx-doc/sphinx/blob + # /2a65ffeef5c107c19084fabdd706cdff3f52d93c/sphinx/domains/python.py#L843 + "ref.python" +] + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # 'preamble': '', + # Latex figure (float) alignment + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + ( + root_doc, + "google-ai-generativelanguage.tex", + u"google-ai-generativelanguage Documentation", + author, + "manual", + ) +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +# latex_logo = None + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +# latex_use_parts = False + +# If true, show page references after internal links. +# latex_show_pagerefs = False + +# If true, show URL addresses after external links. +# latex_show_urls = False + +# Documents to append as an appendix to all manuals. +# latex_appendices = [] + +# If false, no module index is generated. +# latex_domain_indices = True + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + ( + root_doc, + "google-ai-generativelanguage", + u"Google Ai Generativelanguage Documentation", + [author], + 1, + ) +] + +# If true, show URL addresses after external links. +# man_show_urls = False + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. 
List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ( + root_doc, + "google-ai-generativelanguage", + u"google-ai-generativelanguage Documentation", + author, + "google-ai-generativelanguage", + "GAPIC library for Google Ai Generativelanguage API", + "APIs", + ) +] + +# Documents to append as an appendix to all manuals. +# texinfo_appendices = [] + +# If false, no module index is generated. +# texinfo_domain_indices = True + +# How to display URL addresses: 'footnote', 'no', or 'inline'. +# texinfo_show_urls = 'footnote' + +# If true, do not generate a @detailmenu in the "Top" node's menu. +# texinfo_no_detailmenu = False + + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + "python": ("http://python.readthedocs.org/en/latest/", None), + "gax": ("https://gax-python.readthedocs.org/en/latest/", None), + "google-auth": ("https://google-auth.readthedocs.io/en/stable", None), + "google-gax": ("https://gax-python.readthedocs.io/en/latest/", None), + "google.api_core": ("https://googleapis.dev/python/google-api-core/latest/", None), + "grpc": ("https://grpc.io/grpc/python/", None), + "requests": ("http://requests.kennethreitz.org/en/stable/", None), + "proto": ("https://proto-plus-python.readthedocs.io/en/stable", None), + "protobuf": ("https://googleapis.dev/python/protobuf/latest/", None), +} + + +# Napoleon settings +napoleon_google_docstring = True +napoleon_numpy_docstring = True +napoleon_include_private_with_doc = False +napoleon_include_special_with_doc = True +napoleon_use_admonition_for_examples = False +napoleon_use_admonition_for_notes = False +napoleon_use_admonition_for_references = False +napoleon_use_ivar = False +napoleon_use_param = True +napoleon_use_rtype = True diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/discuss_service.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/discuss_service.rst new file mode 100644 index 000000000000..8da7dd42804a --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/discuss_service.rst @@ -0,0 +1,6 @@ +DiscussService +-------------------------------- + +.. automodule:: google.ai.generativelanguage_v1beta3.services.discuss_service + :members: + :inherited-members: diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/model_service.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/model_service.rst new file mode 100644 index 000000000000..92b24f45230b --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/model_service.rst @@ -0,0 +1,10 @@ +ModelService +------------------------------ + +.. automodule:: google.ai.generativelanguage_v1beta3.services.model_service + :members: + :inherited-members: + +.. 
automodule:: google.ai.generativelanguage_v1beta3.services.model_service.pagers + :members: + :inherited-members: diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/permission_service.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/permission_service.rst new file mode 100644 index 000000000000..e645aebb7270 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/permission_service.rst @@ -0,0 +1,10 @@ +PermissionService +----------------------------------- + +.. automodule:: google.ai.generativelanguage_v1beta3.services.permission_service + :members: + :inherited-members: + +.. automodule:: google.ai.generativelanguage_v1beta3.services.permission_service.pagers + :members: + :inherited-members: diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/services.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/services.rst new file mode 100644 index 000000000000..377565194acd --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/services.rst @@ -0,0 +1,9 @@ +Services for Google Ai Generativelanguage v1beta3 API +===================================================== +.. toctree:: + :maxdepth: 2 + + discuss_service + model_service + permission_service + text_service diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/text_service.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/text_service.rst new file mode 100644 index 000000000000..cdc879c1ff11 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/text_service.rst @@ -0,0 +1,6 @@ +TextService +----------------------------- + +.. automodule:: google.ai.generativelanguage_v1beta3.services.text_service + :members: + :inherited-members: diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/types.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/types.rst new file mode 100644 index 000000000000..4cfc946f42d4 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/types.rst @@ -0,0 +1,6 @@ +Types for Google Ai Generativelanguage v1beta3 API +================================================== + +.. automodule:: google.ai.generativelanguage_v1beta3.types + :members: + :show-inheritance: diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/index.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/index.rst new file mode 100644 index 000000000000..d08223c1a59b --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/index.rst @@ -0,0 +1,7 @@ +API Reference +------------- +.. 
toctree:: + :maxdepth: 2 + + generativelanguage_v1beta3/services + generativelanguage_v1beta3/types diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/__init__.py new file mode 100644 index 000000000000..77d8cbc1869c --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/__init__.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from google.ai.generativelanguage import gapic_version as package_version + +__version__ = package_version.__version__ + + +from google.ai.generativelanguage_v1beta3.services.discuss_service.client import DiscussServiceClient +from google.ai.generativelanguage_v1beta3.services.discuss_service.async_client import DiscussServiceAsyncClient +from google.ai.generativelanguage_v1beta3.services.model_service.client import ModelServiceClient +from google.ai.generativelanguage_v1beta3.services.model_service.async_client import ModelServiceAsyncClient +from google.ai.generativelanguage_v1beta3.services.permission_service.client import PermissionServiceClient +from google.ai.generativelanguage_v1beta3.services.permission_service.async_client import PermissionServiceAsyncClient +from google.ai.generativelanguage_v1beta3.services.text_service.client import TextServiceClient +from google.ai.generativelanguage_v1beta3.services.text_service.async_client import TextServiceAsyncClient + +from google.ai.generativelanguage_v1beta3.types.citation import CitationMetadata +from google.ai.generativelanguage_v1beta3.types.citation import CitationSource +from google.ai.generativelanguage_v1beta3.types.discuss_service import CountMessageTokensRequest +from google.ai.generativelanguage_v1beta3.types.discuss_service import CountMessageTokensResponse +from google.ai.generativelanguage_v1beta3.types.discuss_service import Example +from google.ai.generativelanguage_v1beta3.types.discuss_service import GenerateMessageRequest +from google.ai.generativelanguage_v1beta3.types.discuss_service import GenerateMessageResponse +from google.ai.generativelanguage_v1beta3.types.discuss_service import Message +from google.ai.generativelanguage_v1beta3.types.discuss_service import MessagePrompt +from google.ai.generativelanguage_v1beta3.types.model import Model +from google.ai.generativelanguage_v1beta3.types.model_service import CreateTunedModelMetadata +from google.ai.generativelanguage_v1beta3.types.model_service import CreateTunedModelRequest +from google.ai.generativelanguage_v1beta3.types.model_service import DeleteTunedModelRequest +from google.ai.generativelanguage_v1beta3.types.model_service import GetModelRequest +from google.ai.generativelanguage_v1beta3.types.model_service import GetTunedModelRequest +from google.ai.generativelanguage_v1beta3.types.model_service import ListModelsRequest +from 
google.ai.generativelanguage_v1beta3.types.model_service import ListModelsResponse +from google.ai.generativelanguage_v1beta3.types.model_service import ListTunedModelsRequest +from google.ai.generativelanguage_v1beta3.types.model_service import ListTunedModelsResponse +from google.ai.generativelanguage_v1beta3.types.model_service import UpdateTunedModelRequest +from google.ai.generativelanguage_v1beta3.types.permission import Permission +from google.ai.generativelanguage_v1beta3.types.permission_service import CreatePermissionRequest +from google.ai.generativelanguage_v1beta3.types.permission_service import DeletePermissionRequest +from google.ai.generativelanguage_v1beta3.types.permission_service import GetPermissionRequest +from google.ai.generativelanguage_v1beta3.types.permission_service import ListPermissionsRequest +from google.ai.generativelanguage_v1beta3.types.permission_service import ListPermissionsResponse +from google.ai.generativelanguage_v1beta3.types.permission_service import TransferOwnershipRequest +from google.ai.generativelanguage_v1beta3.types.permission_service import TransferOwnershipResponse +from google.ai.generativelanguage_v1beta3.types.permission_service import UpdatePermissionRequest +from google.ai.generativelanguage_v1beta3.types.safety import ContentFilter +from google.ai.generativelanguage_v1beta3.types.safety import SafetyFeedback +from google.ai.generativelanguage_v1beta3.types.safety import SafetyRating +from google.ai.generativelanguage_v1beta3.types.safety import SafetySetting +from google.ai.generativelanguage_v1beta3.types.safety import HarmCategory +from google.ai.generativelanguage_v1beta3.types.text_service import BatchEmbedTextRequest +from google.ai.generativelanguage_v1beta3.types.text_service import BatchEmbedTextResponse +from google.ai.generativelanguage_v1beta3.types.text_service import CountTextTokensRequest +from google.ai.generativelanguage_v1beta3.types.text_service import CountTextTokensResponse +from google.ai.generativelanguage_v1beta3.types.text_service import Embedding +from google.ai.generativelanguage_v1beta3.types.text_service import EmbedTextRequest +from google.ai.generativelanguage_v1beta3.types.text_service import EmbedTextResponse +from google.ai.generativelanguage_v1beta3.types.text_service import GenerateTextRequest +from google.ai.generativelanguage_v1beta3.types.text_service import GenerateTextResponse +from google.ai.generativelanguage_v1beta3.types.text_service import TextCompletion +from google.ai.generativelanguage_v1beta3.types.text_service import TextPrompt +from google.ai.generativelanguage_v1beta3.types.tuned_model import Dataset +from google.ai.generativelanguage_v1beta3.types.tuned_model import Hyperparameters +from google.ai.generativelanguage_v1beta3.types.tuned_model import TunedModel +from google.ai.generativelanguage_v1beta3.types.tuned_model import TunedModelSource +from google.ai.generativelanguage_v1beta3.types.tuned_model import TuningExample +from google.ai.generativelanguage_v1beta3.types.tuned_model import TuningExamples +from google.ai.generativelanguage_v1beta3.types.tuned_model import TuningSnapshot +from google.ai.generativelanguage_v1beta3.types.tuned_model import TuningTask + +__all__ = ('DiscussServiceClient', + 'DiscussServiceAsyncClient', + 'ModelServiceClient', + 'ModelServiceAsyncClient', + 'PermissionServiceClient', + 'PermissionServiceAsyncClient', + 'TextServiceClient', + 'TextServiceAsyncClient', + 'CitationMetadata', + 'CitationSource', + 'CountMessageTokensRequest', + 
'CountMessageTokensResponse', + 'Example', + 'GenerateMessageRequest', + 'GenerateMessageResponse', + 'Message', + 'MessagePrompt', + 'Model', + 'CreateTunedModelMetadata', + 'CreateTunedModelRequest', + 'DeleteTunedModelRequest', + 'GetModelRequest', + 'GetTunedModelRequest', + 'ListModelsRequest', + 'ListModelsResponse', + 'ListTunedModelsRequest', + 'ListTunedModelsResponse', + 'UpdateTunedModelRequest', + 'Permission', + 'CreatePermissionRequest', + 'DeletePermissionRequest', + 'GetPermissionRequest', + 'ListPermissionsRequest', + 'ListPermissionsResponse', + 'TransferOwnershipRequest', + 'TransferOwnershipResponse', + 'UpdatePermissionRequest', + 'ContentFilter', + 'SafetyFeedback', + 'SafetyRating', + 'SafetySetting', + 'HarmCategory', + 'BatchEmbedTextRequest', + 'BatchEmbedTextResponse', + 'CountTextTokensRequest', + 'CountTextTokensResponse', + 'Embedding', + 'EmbedTextRequest', + 'EmbedTextResponse', + 'GenerateTextRequest', + 'GenerateTextResponse', + 'TextCompletion', + 'TextPrompt', + 'Dataset', + 'Hyperparameters', + 'TunedModel', + 'TunedModelSource', + 'TuningExample', + 'TuningExamples', + 'TuningSnapshot', + 'TuningTask', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/gapic_version.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/gapic_version.py new file mode 100644 index 000000000000..360a0d13ebdd --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/gapic_version.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +__version__ = "0.0.0" # {x-release-please-version} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/py.typed b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/py.typed new file mode 100644 index 000000000000..38773eee6363 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-ai-generativelanguage package uses inline types. diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/__init__.py new file mode 100644 index 000000000000..264895e674fe --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/__init__.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version + +__version__ = package_version.__version__ + + +from .services.discuss_service import DiscussServiceClient +from .services.discuss_service import DiscussServiceAsyncClient +from .services.model_service import ModelServiceClient +from .services.model_service import ModelServiceAsyncClient +from .services.permission_service import PermissionServiceClient +from .services.permission_service import PermissionServiceAsyncClient +from .services.text_service import TextServiceClient +from .services.text_service import TextServiceAsyncClient + +from .types.citation import CitationMetadata +from .types.citation import CitationSource +from .types.discuss_service import CountMessageTokensRequest +from .types.discuss_service import CountMessageTokensResponse +from .types.discuss_service import Example +from .types.discuss_service import GenerateMessageRequest +from .types.discuss_service import GenerateMessageResponse +from .types.discuss_service import Message +from .types.discuss_service import MessagePrompt +from .types.model import Model +from .types.model_service import CreateTunedModelMetadata +from .types.model_service import CreateTunedModelRequest +from .types.model_service import DeleteTunedModelRequest +from .types.model_service import GetModelRequest +from .types.model_service import GetTunedModelRequest +from .types.model_service import ListModelsRequest +from .types.model_service import ListModelsResponse +from .types.model_service import ListTunedModelsRequest +from .types.model_service import ListTunedModelsResponse +from .types.model_service import UpdateTunedModelRequest +from .types.permission import Permission +from .types.permission_service import CreatePermissionRequest +from .types.permission_service import DeletePermissionRequest +from .types.permission_service import GetPermissionRequest +from .types.permission_service import ListPermissionsRequest +from .types.permission_service import ListPermissionsResponse +from .types.permission_service import TransferOwnershipRequest +from .types.permission_service import TransferOwnershipResponse +from .types.permission_service import UpdatePermissionRequest +from .types.safety import ContentFilter +from .types.safety import SafetyFeedback +from .types.safety import SafetyRating +from .types.safety import SafetySetting +from .types.safety import HarmCategory +from .types.text_service import BatchEmbedTextRequest +from .types.text_service import BatchEmbedTextResponse +from .types.text_service import CountTextTokensRequest +from .types.text_service import CountTextTokensResponse +from .types.text_service import Embedding +from .types.text_service import EmbedTextRequest +from .types.text_service import EmbedTextResponse +from .types.text_service import GenerateTextRequest +from .types.text_service import GenerateTextResponse +from .types.text_service import TextCompletion +from .types.text_service import TextPrompt +from .types.tuned_model import Dataset +from .types.tuned_model import Hyperparameters +from .types.tuned_model import TunedModel 
+from .types.tuned_model import TunedModelSource +from .types.tuned_model import TuningExample +from .types.tuned_model import TuningExamples +from .types.tuned_model import TuningSnapshot +from .types.tuned_model import TuningTask + +__all__ = ( + 'DiscussServiceAsyncClient', + 'ModelServiceAsyncClient', + 'PermissionServiceAsyncClient', + 'TextServiceAsyncClient', +'BatchEmbedTextRequest', +'BatchEmbedTextResponse', +'CitationMetadata', +'CitationSource', +'ContentFilter', +'CountMessageTokensRequest', +'CountMessageTokensResponse', +'CountTextTokensRequest', +'CountTextTokensResponse', +'CreatePermissionRequest', +'CreateTunedModelMetadata', +'CreateTunedModelRequest', +'Dataset', +'DeletePermissionRequest', +'DeleteTunedModelRequest', +'DiscussServiceClient', +'EmbedTextRequest', +'EmbedTextResponse', +'Embedding', +'Example', +'GenerateMessageRequest', +'GenerateMessageResponse', +'GenerateTextRequest', +'GenerateTextResponse', +'GetModelRequest', +'GetPermissionRequest', +'GetTunedModelRequest', +'HarmCategory', +'Hyperparameters', +'ListModelsRequest', +'ListModelsResponse', +'ListPermissionsRequest', +'ListPermissionsResponse', +'ListTunedModelsRequest', +'ListTunedModelsResponse', +'Message', +'MessagePrompt', +'Model', +'ModelServiceClient', +'Permission', +'PermissionServiceClient', +'SafetyFeedback', +'SafetyRating', +'SafetySetting', +'TextCompletion', +'TextPrompt', +'TextServiceClient', +'TransferOwnershipRequest', +'TransferOwnershipResponse', +'TunedModel', +'TunedModelSource', +'TuningExample', +'TuningExamples', +'TuningSnapshot', +'TuningTask', +'UpdatePermissionRequest', +'UpdateTunedModelRequest', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_metadata.json b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_metadata.json new file mode 100644 index 000000000000..a7ee1b83ccd5 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_metadata.json @@ -0,0 +1,370 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.ai.generativelanguage_v1beta3", + "protoPackage": "google.ai.generativelanguage.v1beta3", + "schema": "1.0", + "services": { + "DiscussService": { + "clients": { + "grpc": { + "libraryClient": "DiscussServiceClient", + "rpcs": { + "CountMessageTokens": { + "methods": [ + "count_message_tokens" + ] + }, + "GenerateMessage": { + "methods": [ + "generate_message" + ] + } + } + }, + "grpc-async": { + "libraryClient": "DiscussServiceAsyncClient", + "rpcs": { + "CountMessageTokens": { + "methods": [ + "count_message_tokens" + ] + }, + "GenerateMessage": { + "methods": [ + "generate_message" + ] + } + } + }, + "rest": { + "libraryClient": "DiscussServiceClient", + "rpcs": { + "CountMessageTokens": { + "methods": [ + "count_message_tokens" + ] + }, + "GenerateMessage": { + "methods": [ + "generate_message" + ] + } + } + } + } + }, + "ModelService": { + "clients": { + "grpc": { + "libraryClient": "ModelServiceClient", + "rpcs": { + "CreateTunedModel": { + "methods": [ + "create_tuned_model" + ] + }, + "DeleteTunedModel": { + "methods": [ + "delete_tuned_model" + ] + }, + "GetModel": { + "methods": [ + "get_model" + ] + }, + "GetTunedModel": { + "methods": [ + "get_tuned_model" + ] + }, + "ListModels": { + "methods": [ + "list_models" + ] + }, + "ListTunedModels": { + "methods": [ + "list_tuned_models" + 
] + }, + "UpdateTunedModel": { + "methods": [ + "update_tuned_model" + ] + } + } + }, + "grpc-async": { + "libraryClient": "ModelServiceAsyncClient", + "rpcs": { + "CreateTunedModel": { + "methods": [ + "create_tuned_model" + ] + }, + "DeleteTunedModel": { + "methods": [ + "delete_tuned_model" + ] + }, + "GetModel": { + "methods": [ + "get_model" + ] + }, + "GetTunedModel": { + "methods": [ + "get_tuned_model" + ] + }, + "ListModels": { + "methods": [ + "list_models" + ] + }, + "ListTunedModels": { + "methods": [ + "list_tuned_models" + ] + }, + "UpdateTunedModel": { + "methods": [ + "update_tuned_model" + ] + } + } + }, + "rest": { + "libraryClient": "ModelServiceClient", + "rpcs": { + "CreateTunedModel": { + "methods": [ + "create_tuned_model" + ] + }, + "DeleteTunedModel": { + "methods": [ + "delete_tuned_model" + ] + }, + "GetModel": { + "methods": [ + "get_model" + ] + }, + "GetTunedModel": { + "methods": [ + "get_tuned_model" + ] + }, + "ListModels": { + "methods": [ + "list_models" + ] + }, + "ListTunedModels": { + "methods": [ + "list_tuned_models" + ] + }, + "UpdateTunedModel": { + "methods": [ + "update_tuned_model" + ] + } + } + } + } + }, + "PermissionService": { + "clients": { + "grpc": { + "libraryClient": "PermissionServiceClient", + "rpcs": { + "CreatePermission": { + "methods": [ + "create_permission" + ] + }, + "DeletePermission": { + "methods": [ + "delete_permission" + ] + }, + "GetPermission": { + "methods": [ + "get_permission" + ] + }, + "ListPermissions": { + "methods": [ + "list_permissions" + ] + }, + "TransferOwnership": { + "methods": [ + "transfer_ownership" + ] + }, + "UpdatePermission": { + "methods": [ + "update_permission" + ] + } + } + }, + "grpc-async": { + "libraryClient": "PermissionServiceAsyncClient", + "rpcs": { + "CreatePermission": { + "methods": [ + "create_permission" + ] + }, + "DeletePermission": { + "methods": [ + "delete_permission" + ] + }, + "GetPermission": { + "methods": [ + "get_permission" + ] + }, + "ListPermissions": { + "methods": [ + "list_permissions" + ] + }, + "TransferOwnership": { + "methods": [ + "transfer_ownership" + ] + }, + "UpdatePermission": { + "methods": [ + "update_permission" + ] + } + } + }, + "rest": { + "libraryClient": "PermissionServiceClient", + "rpcs": { + "CreatePermission": { + "methods": [ + "create_permission" + ] + }, + "DeletePermission": { + "methods": [ + "delete_permission" + ] + }, + "GetPermission": { + "methods": [ + "get_permission" + ] + }, + "ListPermissions": { + "methods": [ + "list_permissions" + ] + }, + "TransferOwnership": { + "methods": [ + "transfer_ownership" + ] + }, + "UpdatePermission": { + "methods": [ + "update_permission" + ] + } + } + } + } + }, + "TextService": { + "clients": { + "grpc": { + "libraryClient": "TextServiceClient", + "rpcs": { + "BatchEmbedText": { + "methods": [ + "batch_embed_text" + ] + }, + "CountTextTokens": { + "methods": [ + "count_text_tokens" + ] + }, + "EmbedText": { + "methods": [ + "embed_text" + ] + }, + "GenerateText": { + "methods": [ + "generate_text" + ] + } + } + }, + "grpc-async": { + "libraryClient": "TextServiceAsyncClient", + "rpcs": { + "BatchEmbedText": { + "methods": [ + "batch_embed_text" + ] + }, + "CountTextTokens": { + "methods": [ + "count_text_tokens" + ] + }, + "EmbedText": { + "methods": [ + "embed_text" + ] + }, + "GenerateText": { + "methods": [ + "generate_text" + ] + } + } + }, + "rest": { + "libraryClient": "TextServiceClient", + "rpcs": { + "BatchEmbedText": { + "methods": [ + "batch_embed_text" + ] + }, + "CountTextTokens": { 
+ "methods": [ + "count_text_tokens" + ] + }, + "EmbedText": { + "methods": [ + "embed_text" + ] + }, + "GenerateText": { + "methods": [ + "generate_text" + ] + } + } + } + } + } + } +} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_version.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_version.py new file mode 100644 index 000000000000..360a0d13ebdd --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_version.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +__version__ = "0.0.0" # {x-release-please-version} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/py.typed b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/py.typed new file mode 100644 index 000000000000..38773eee6363 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-ai-generativelanguage package uses inline types. diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/__init__.py new file mode 100644 index 000000000000..89a37dc92c5a --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/__init__.py new file mode 100644 index 000000000000..c5c6e8208269 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .client import DiscussServiceClient +from .async_client import DiscussServiceAsyncClient + +__all__ = ( + 'DiscussServiceClient', + 'DiscussServiceAsyncClient', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/async_client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/async_client.py new file mode 100644 index 000000000000..1f9cde10aa21 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/async_client.py @@ -0,0 +1,509 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import functools +import re +from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version + +from google.api_core.client_options import ClientOptions +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.ai.generativelanguage_v1beta3.types import discuss_service +from google.ai.generativelanguage_v1beta3.types import safety +from google.longrunning import operations_pb2 # type: ignore +from .transports.base import DiscussServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import DiscussServiceGrpcAsyncIOTransport +from .client import DiscussServiceClient + + +class DiscussServiceAsyncClient: + """An API for using Generative Language Models (GLMs) in dialog + applications. + Also known as large language models (LLMs), this API provides + models that are trained for multi-turn dialog. 
+ """ + + _client: DiscussServiceClient + + DEFAULT_ENDPOINT = DiscussServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = DiscussServiceClient.DEFAULT_MTLS_ENDPOINT + + model_path = staticmethod(DiscussServiceClient.model_path) + parse_model_path = staticmethod(DiscussServiceClient.parse_model_path) + common_billing_account_path = staticmethod(DiscussServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(DiscussServiceClient.parse_common_billing_account_path) + common_folder_path = staticmethod(DiscussServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(DiscussServiceClient.parse_common_folder_path) + common_organization_path = staticmethod(DiscussServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(DiscussServiceClient.parse_common_organization_path) + common_project_path = staticmethod(DiscussServiceClient.common_project_path) + parse_common_project_path = staticmethod(DiscussServiceClient.parse_common_project_path) + common_location_path = staticmethod(DiscussServiceClient.common_location_path) + parse_common_location_path = staticmethod(DiscussServiceClient.parse_common_location_path) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DiscussServiceAsyncClient: The constructed client. + """ + return DiscussServiceClient.from_service_account_info.__func__(DiscussServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DiscussServiceAsyncClient: The constructed client. + """ + return DiscussServiceClient.from_service_account_file.__func__(DiscussServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. 
+                client. Only the `api_endpoint` and `client_cert_source` properties may be used
+                in this method.
+
+        Returns:
+            Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+                client cert source to use.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+        """
+        return DiscussServiceClient.get_mtls_endpoint_and_cert_source(client_options)  # type: ignore
+
+    @property
+    def transport(self) -> DiscussServiceTransport:
+        """Returns the transport used by the client instance.
+
+        Returns:
+            DiscussServiceTransport: The transport used by the client instance.
+        """
+        return self._client.transport
+
+    get_transport_class = functools.partial(type(DiscussServiceClient).get_transport_class, type(DiscussServiceClient))
+
+    def __init__(self, *,
+            credentials: Optional[ga_credentials.Credentials] = None,
+            transport: Union[str, DiscussServiceTransport] = "grpc_asyncio",
+            client_options: Optional[ClientOptions] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            ) -> None:
+        """Instantiates the discuss service client.
+
+        Args:
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            transport (Union[str, ~.DiscussServiceTransport]): The
+                transport to use. If set to None, a transport is chosen
+                automatically.
+            client_options (ClientOptions): Custom options for the client. It
+                won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto switch to the
+                default mTLS endpoint if client certificate is present, this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+        """
+        self._client = DiscussServiceClient(
+            credentials=credentials,
+            transport=transport,
+            client_options=client_options,
+            client_info=client_info,
+
+        )
+
+    async def generate_message(self,
+            request: Optional[Union[discuss_service.GenerateMessageRequest, dict]] = None,
+            *,
+            model: Optional[str] = None,
+            prompt: Optional[discuss_service.MessagePrompt] = None,
+            temperature: Optional[float] = None,
+            candidate_count: Optional[int] = None,
+            top_p: Optional[float] = None,
+            top_k: Optional[int] = None,
+            retry: OptionalRetry = gapic_v1.method.DEFAULT,
+            timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> discuss_service.GenerateMessageResponse:
+        r"""Generates a response from the model given an input
+        ``MessagePrompt``.
+
+        .. code-block:: python
+
+            # This snippet has been automatically generated and should be regarded as a
+            # code template only.
+ # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + async def sample_generate_message(): + # Create a client + client = generativelanguage_v1beta3.DiscussServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta3.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1beta3.GenerateMessageRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.generate_message(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta3.types.GenerateMessageRequest, dict]]): + The request object. Request to generate a message + response from the model. + model (:class:`str`): + Required. The name of the model to use. + + Format: ``name=models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (:class:`google.ai.generativelanguage_v1beta3.types.MessagePrompt`): + Required. The structured textual + input given to the model as a prompt. + Given a + prompt, the model will return what it + predicts is the next message in the + discussion. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + temperature (:class:`float`): + Optional. Controls the randomness of the output. + + Values can range over ``[0.0,1.0]``, inclusive. A value + closer to ``1.0`` will produce responses that are more + varied, while a value closer to ``0.0`` will typically + result in less surprising responses from the model. + + This corresponds to the ``temperature`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + candidate_count (:class:`int`): + Optional. The number of generated response messages to + return. + + This value must be between ``[1, 8]``, inclusive. If + unset, this will default to ``1``. + + This corresponds to the ``candidate_count`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_p (:class:`float`): + Optional. The maximum cumulative probability of tokens + to consider when sampling. + + The model uses combined Top-k and nucleus sampling. + + Nucleus sampling considers the smallest set of tokens + whose probability sum is at least ``top_p``. + + This corresponds to the ``top_p`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_k (:class:`int`): + Optional. The maximum number of tokens to consider when + sampling. + + The model uses combined Top-k and nucleus sampling. + + Top-k sampling considers the set of ``top_k`` most + probable tokens. + + This corresponds to the ``top_k`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.GenerateMessageResponse: + The response from the model. 
+ + This includes candidate messages and + conversation history in the form of + chronologically-ordered messages. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt, temperature, candidate_count, top_p, top_k]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = discuss_service.GenerateMessageRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + if temperature is not None: + request.temperature = temperature + if candidate_count is not None: + request.candidate_count = candidate_count + if top_p is not None: + request.top_p = top_p + if top_k is not None: + request.top_k = top_k + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.generate_message, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def count_message_tokens(self, + request: Optional[Union[discuss_service.CountMessageTokensRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> discuss_service.CountMessageTokensResponse: + r"""Runs a model's tokenizer on a string and returns the + token count. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + async def sample_count_message_tokens(): + # Create a client + client = generativelanguage_v1beta3.DiscussServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta3.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1beta3.CountMessageTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.count_message_tokens(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta3.types.CountMessageTokensRequest, dict]]): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + model (:class:`str`): + Required. The model's resource name. 
This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (:class:`google.ai.generativelanguage_v1beta3.types.MessagePrompt`): + Required. The prompt, whose token + count is to be returned. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.CountMessageTokensResponse: + A response from CountMessageTokens. + + It returns the model's token_count for the prompt. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = discuss_service.CountMessageTokensRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.count_message_tokens, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "DiscussServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +__all__ = ( + "DiscussServiceAsyncClient", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/client.py new file mode 100644 index 000000000000..1e3b5952d0bf --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/client.py @@ -0,0 +1,717 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import os +import re +from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.ai.generativelanguage_v1beta3.types import discuss_service +from google.ai.generativelanguage_v1beta3.types import safety +from google.longrunning import operations_pb2 # type: ignore +from .transports.base import DiscussServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import DiscussServiceGrpcTransport +from .transports.grpc_asyncio import DiscussServiceGrpcAsyncIOTransport +from .transports.rest import DiscussServiceRestTransport + + +class DiscussServiceClientMeta(type): + """Metaclass for the DiscussService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + _transport_registry = OrderedDict() # type: Dict[str, Type[DiscussServiceTransport]] + _transport_registry["grpc"] = DiscussServiceGrpcTransport + _transport_registry["grpc_asyncio"] = DiscussServiceGrpcAsyncIOTransport + _transport_registry["rest"] = DiscussServiceRestTransport + + def get_transport_class(cls, + label: Optional[str] = None, + ) -> Type[DiscussServiceTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class DiscussServiceClient(metaclass=DiscussServiceClientMeta): + """An API for using Generative Language Models (GLMs) in dialog + applications. + Also known as large language models (LLMs), this API provides + models that are trained for multi-turn dialog. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. 
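+
+        For instance, a quick sketch of the conversion rule above (illustrative
+        only; the endpoint literals are the documented defaults, and the calls
+        simply mirror the implementation below):
+
+        .. code-block:: python
+
+            # Google API hosts gain an ``.mtls`` label; other hosts pass through.
+            assert DiscussServiceClient._get_default_mtls_endpoint(
+                "generativelanguage.googleapis.com"
+            ) == "generativelanguage.mtls.googleapis.com"
+            assert DiscussServiceClient._get_default_mtls_endpoint(
+                "example.com"
+            ) == "example.com"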
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = "generativelanguage.googleapis.com"
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
+    )
+
+    @classmethod
+    def from_service_account_info(cls, info: dict, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials
+            info.
+
+        Args:
+            info (dict): The service account private key info.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            DiscussServiceClient: The constructed client.
+        """
+        credentials = service_account.Credentials.from_service_account_info(info)
+        kwargs["credentials"] = credentials
+        return cls(*args, **kwargs)
+
+    @classmethod
+    def from_service_account_file(cls, filename: str, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials
+            file.
+
+        Args:
+            filename (str): The path to the service account private key json
+                file.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            DiscussServiceClient: The constructed client.
+        """
+        credentials = service_account.Credentials.from_service_account_file(
+            filename)
+        kwargs["credentials"] = credentials
+        return cls(*args, **kwargs)
+
+    from_service_account_json = from_service_account_file
+
+    @property
+    def transport(self) -> DiscussServiceTransport:
+        """Returns the transport used by the client instance.
+
+        Returns:
+            DiscussServiceTransport: The transport used by the client
+                instance.
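+
+        A hedged usage sketch (assumes ambient default credentials; the
+        ``"rest"`` label comes from this client's transport registry, and the
+        ``kind`` value shown is an assumption about the REST transport):
+
+        .. code-block:: python
+
+            client = DiscussServiceClient(transport="rest")
+            print(client.transport.kind)  # expected: "rest"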
+ """ + return self._transport + + @staticmethod + def model_path(model: str,) -> str: + """Returns a fully-qualified model string.""" + return "models/{model}".format(model=model, ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str,str]: + """Parses a model path into its component segments.""" + m = re.match(r"^models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. 
+ + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`") + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`") + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + def __init__(self, *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[Union[str, DiscussServiceTransport]] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the discuss service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, DiscussServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. 
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + client_options = cast(client_options_lib.ClientOptions, client_options) + + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) + + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError("client_options.api_key and credentials are mutually exclusive") + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, DiscussServiceTransport): + # transport is a DiscussServiceTransport instance. + if credentials or client_options.credentials_file or api_key_value: + raise ValueError("When providing a transport instance, " + "provide its credentials directly.") + if client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = transport + else: + import google.auth._default # type: ignore + + if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): + credentials = google.auth._default.get_api_key_credentials(api_key_value) + + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=client_options.api_audience, + ) + + def generate_message(self, + request: Optional[Union[discuss_service.GenerateMessageRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> discuss_service.GenerateMessageResponse: + r"""Generates a response from the model given an input + ``MessagePrompt``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_generate_message(): + # Create a client + client = generativelanguage_v1beta3.DiscussServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta3.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1beta3.GenerateMessageRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.generate_message(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.GenerateMessageRequest, dict]): + The request object. Request to generate a message + response from the model. + model (str): + Required. The name of the model to use. + + Format: ``name=models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (google.ai.generativelanguage_v1beta3.types.MessagePrompt): + Required. The structured textual + input given to the model as a prompt. + Given a + prompt, the model will return what it + predicts is the next message in the + discussion. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + temperature (float): + Optional. Controls the randomness of the output. + + Values can range over ``[0.0,1.0]``, inclusive. A value + closer to ``1.0`` will produce responses that are more + varied, while a value closer to ``0.0`` will typically + result in less surprising responses from the model. + + This corresponds to the ``temperature`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + candidate_count (int): + Optional. The number of generated response messages to + return. + + This value must be between ``[1, 8]``, inclusive. If + unset, this will default to ``1``. + + This corresponds to the ``candidate_count`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_p (float): + Optional. The maximum cumulative probability of tokens + to consider when sampling. + + The model uses combined Top-k and nucleus sampling. + + Nucleus sampling considers the smallest set of tokens + whose probability sum is at least ``top_p``. + + This corresponds to the ``top_p`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_k (int): + Optional. The maximum number of tokens to consider when + sampling. + + The model uses combined Top-k and nucleus sampling. + + Top-k sampling considers the set of ``top_k`` most + probable tokens. + + This corresponds to the ``top_k`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.GenerateMessageResponse: + The response from the model. + + This includes candidate messages and + conversation history in the form of + chronologically-ordered messages. + + """ + # Create or coerce a protobuf request object. 
+ # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt, temperature, candidate_count, top_p, top_k]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a discuss_service.GenerateMessageRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, discuss_service.GenerateMessageRequest): + request = discuss_service.GenerateMessageRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + if temperature is not None: + request.temperature = temperature + if candidate_count is not None: + request.candidate_count = candidate_count + if top_p is not None: + request.top_p = top_p + if top_k is not None: + request.top_k = top_k + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.generate_message] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def count_message_tokens(self, + request: Optional[Union[discuss_service.CountMessageTokensRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> discuss_service.CountMessageTokensResponse: + r"""Runs a model's tokenizer on a string and returns the + token count. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_count_message_tokens(): + # Create a client + client = generativelanguage_v1beta3.DiscussServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta3.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1beta3.CountMessageTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.count_message_tokens(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.CountMessageTokensRequest, dict]): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + model (str): + Required. The model's resource name. This serves as an + ID for the Model to use. 
+ + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (google.ai.generativelanguage_v1beta3.types.MessagePrompt): + Required. The prompt, whose token + count is to be returned. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.CountMessageTokensResponse: + A response from CountMessageTokens. + + It returns the model's token_count for the prompt. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a discuss_service.CountMessageTokensRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, discuss_service.CountMessageTokensRequest): + request = discuss_service.CountMessageTokensRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.count_message_tokens] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "DiscussServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! 
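+
+        A minimal sketch of the intended pattern (assumes ambient default
+        credentials; the transport is closed when the block exits):
+
+        .. code-block:: python
+
+            with DiscussServiceClient() as client:
+                ...  # issue RPCs through `client`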
+ """ + self.transport.close() + + + + + + + + + + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +__all__ = ( + "DiscussServiceClient", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/__init__.py new file mode 100644 index 000000000000..b585c1ce424c --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/__init__.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import DiscussServiceTransport +from .grpc import DiscussServiceGrpcTransport +from .grpc_asyncio import DiscussServiceGrpcAsyncIOTransport +from .rest import DiscussServiceRestTransport +from .rest import DiscussServiceRestInterceptor + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[DiscussServiceTransport]] +_transport_registry['grpc'] = DiscussServiceGrpcTransport +_transport_registry['grpc_asyncio'] = DiscussServiceGrpcAsyncIOTransport +_transport_registry['rest'] = DiscussServiceRestTransport + +__all__ = ( + 'DiscussServiceTransport', + 'DiscussServiceGrpcTransport', + 'DiscussServiceGrpcAsyncIOTransport', + 'DiscussServiceRestTransport', + 'DiscussServiceRestInterceptor', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/base.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/base.py new file mode 100644 index 000000000000..7c455e9f245e --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/base.py @@ -0,0 +1,162 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+import abc
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Union
+
+from google.ai.generativelanguage_v1beta3 import gapic_version as package_version
+
+import google.auth  # type: ignore
+import google.api_core
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
+from google.auth import credentials as ga_credentials  # type: ignore
+from google.oauth2 import service_account  # type: ignore
+
+from google.ai.generativelanguage_v1beta3.types import discuss_service
+from google.longrunning import operations_pb2  # type: ignore
+
+DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__)
+
+
+class DiscussServiceTransport(abc.ABC):
+    """Abstract transport class for DiscussService."""
+
+    AUTH_SCOPES = (
+    )
+
+    DEFAULT_HOST: str = 'generativelanguage.googleapis.com'
+    def __init__(
+            self, *,
+            host: str = DEFAULT_HOST,
+            credentials: Optional[ga_credentials.Credentials] = None,
+            credentials_file: Optional[str] = None,
+            scopes: Optional[Sequence[str]] = None,
+            quota_project_id: Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            always_use_jwt_access: Optional[bool] = False,
+            api_audience: Optional[str] = None,
+            **kwargs,
+            ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]):
+                The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
+        """
+
+        scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES}
+
+        # Save the scopes.
+        self._scopes = scopes
+
+        # If no credentials are provided, then determine the appropriate
+        # defaults.
+        if credentials and credentials_file:
+            raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive")
+
+        if credentials_file is not None:
+            credentials, _ = google.auth.load_credentials_from_file(
+                                credentials_file,
+                                **scopes_kwargs,
+                                quota_project_id=quota_project_id
+                            )
+        elif credentials is None:
+            credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id)
+            # Don't apply the audience when the user passed in a credentials file.
+            if hasattr(credentials, "with_gdch_audience"):
+                credentials = credentials.with_gdch_audience(api_audience if api_audience else host)
+
+        # If the credentials are service account credentials, then always try to use self signed JWT.
+ if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ':' not in host: + host += ':443' + self._host = host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.generate_message: gapic_v1.method.wrap_method( + self.generate_message, + default_timeout=None, + client_info=client_info, + ), + self.count_message_tokens: gapic_v1.method.wrap_method( + self.count_message_tokens, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def generate_message(self) -> Callable[ + [discuss_service.GenerateMessageRequest], + Union[ + discuss_service.GenerateMessageResponse, + Awaitable[discuss_service.GenerateMessageResponse] + ]]: + raise NotImplementedError() + + @property + def count_message_tokens(self) -> Callable[ + [discuss_service.CountMessageTokensRequest], + Union[ + discuss_service.CountMessageTokensResponse, + Awaitable[discuss_service.CountMessageTokensResponse] + ]]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ( + 'DiscussServiceTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc.py new file mode 100644 index 000000000000..3e6abae06b98 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc.py @@ -0,0 +1,296 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import grpc_helpers +from google.api_core import gapic_v1 +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.ai.generativelanguage_v1beta3.types import discuss_service +from google.longrunning import operations_pb2 # type: ignore +from .base import DiscussServiceTransport, DEFAULT_CLIENT_INFO + + +class DiscussServiceGrpcTransport(DiscussServiceTransport): + """gRPC backend transport for DiscussService. + + An API for using Generative Language Models (GLMs) in dialog + applications. 
+    Also known as large language models (LLMs), this API provides
+    models that are trained for multi-turn dialog.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+    _stubs: Dict[str, Callable]
+
+    def __init__(self, *,
+            host: str = 'generativelanguage.googleapis.com',
+            credentials: Optional[ga_credentials.Credentials] = None,
+            credentials_file: Optional[str] = None,
+            scopes: Optional[Sequence[str]] = None,
+            channel: Optional[grpc.Channel] = None,
+            api_mtls_endpoint: Optional[str] = None,
+            client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+            ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None,
+            client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+            quota_project_id: Optional[str] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            always_use_jwt_access: Optional[bool] = False,
+            api_audience: Optional[str] = None,
+            ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]):
+                The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
+            channel (Optional[grpc.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for the grpc channel. It is ignored if ``channel`` is provided.
+            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                A callback to provide client certificate bytes and private key bytes,
+                both in PEM format. It is used to configure a mutual TLS channel. It is
+                ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+            always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+                be used for service account credentials.
+
+        Raises:
+          google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+              creation failed for any reason.
+          google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+              and ``credentials_file`` are passed.
+        """
+        self._grpc_channel = None
+        self._ssl_channel_credentials = ssl_channel_credentials
+        self._stubs: Dict[str, Callable] = {}
+
+        if api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+        if client_cert_source:
+            warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+        if channel:
+            # Ignore credentials if a channel was passed.
+            credentials = False
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+            self._ssl_channel_credentials = None
+
+        else:
+            if api_mtls_endpoint:
+                host = api_mtls_endpoint
+
+                # Create SSL credentials with client_cert_source or application
+                # default SSL credentials.
+                if client_cert_source:
+                    cert, key = client_cert_source()
+                    self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+                        certificate_chain=cert, private_key=key
+                    )
+                else:
+                    self._ssl_channel_credentials = SslCredentials().ssl_credentials
+
+            else:
+                if client_cert_source_for_mtls and not ssl_channel_credentials:
+                    cert, key = client_cert_source_for_mtls()
+                    self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+                        certificate_chain=cert, private_key=key
+                    )
+
+        # The base transport sets the host, credentials and scopes
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+            always_use_jwt_access=always_use_jwt_access,
+            api_audience=api_audience,
+        )
+
+        if not self._grpc_channel:
+            self._grpc_channel = type(self).create_channel(
+                self._host,
+                # use the credentials which are saved
+                credentials=self._credentials,
+                # Set ``credentials_file`` to ``None`` here as
+                # the credentials that we saved earlier should be used.
+                credentials_file=None,
+                scopes=self._scopes,
+                ssl_credentials=self._ssl_channel_credentials,
+                quota_project_id=quota_project_id,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+
+        # Wrap messages. This must be done after self._grpc_channel exists
+        self._prep_wrapped_messages(client_info)
+
+    @classmethod
+    def create_channel(cls,
+                       host: str = 'generativelanguage.googleapis.com',
+                       credentials: Optional[ga_credentials.Credentials] = None,
+                       credentials_file: Optional[str] = None,
+                       scopes: Optional[Sequence[str]] = None,
+                       quota_project_id: Optional[str] = None,
+                       **kwargs) -> grpc.Channel:
+        """Create and return a gRPC channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            grpc.Channel: A gRPC channel object.
+ + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def generate_message(self) -> Callable[ + [discuss_service.GenerateMessageRequest], + discuss_service.GenerateMessageResponse]: + r"""Return a callable for the generate message method over gRPC. + + Generates a response from the model given an input + ``MessagePrompt``. + + Returns: + Callable[[~.GenerateMessageRequest], + ~.GenerateMessageResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'generate_message' not in self._stubs: + self._stubs['generate_message'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.DiscussService/GenerateMessage', + request_serializer=discuss_service.GenerateMessageRequest.serialize, + response_deserializer=discuss_service.GenerateMessageResponse.deserialize, + ) + return self._stubs['generate_message'] + + @property + def count_message_tokens(self) -> Callable[ + [discuss_service.CountMessageTokensRequest], + discuss_service.CountMessageTokensResponse]: + r"""Return a callable for the count message tokens method over gRPC. + + Runs a model's tokenizer on a string and returns the + token count. + + Returns: + Callable[[~.CountMessageTokensRequest], + ~.CountMessageTokensResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'count_message_tokens' not in self._stubs: + self._stubs['count_message_tokens'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.DiscussService/CountMessageTokens', + request_serializer=discuss_service.CountMessageTokensRequest.serialize, + response_deserializer=discuss_service.CountMessageTokensResponse.deserialize, + ) + return self._stubs['count_message_tokens'] + + def close(self): + self.grpc_channel.close() + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ( + 'DiscussServiceGrpcTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc_asyncio.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..48e36ab4d7ad --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc_asyncio.py @@ -0,0 +1,295 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers_async +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.ai.generativelanguage_v1beta3.types import discuss_service +from google.longrunning import operations_pb2 # type: ignore +from .base import DiscussServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import DiscussServiceGrpcTransport + + +class DiscussServiceGrpcAsyncIOTransport(DiscussServiceTransport): + """gRPC AsyncIO backend transport for DiscussService. + + An API for using Generative Language Models (GLMs) in dialog + applications. + Also known as large language models (LLMs), this API provides + models that are trained for multi-turn dialog. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. 
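+
+        For illustration, an editor-added sketch (not generated code) of
+        creating a channel explicitly; it assumes application default
+        credentials are available and the quota project is hypothetical:
+
+        .. code-block:: python
+
+            channel = DiscussServiceGrpcAsyncIOTransport.create_channel(
+                "generativelanguage.googleapis.com",
+                quota_project_id="my-quota-project",  # hypothetical project
+            )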
+ """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs + ) + + def __init__(self, *, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[aio.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
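+
+        As an editor-added illustration (a sketch, not generated code): a
+        pre-built channel can be injected, in which case the credential
+        arguments are ignored because the channel already carries them:
+
+        .. code-block:: python
+
+            # The transport skips channel creation when one is supplied.
+            channel = DiscussServiceGrpcAsyncIOTransport.create_channel()
+            transport = DiscussServiceGrpcAsyncIOTransport(channel=channel)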
+ """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def generate_message(self) -> Callable[ + [discuss_service.GenerateMessageRequest], + Awaitable[discuss_service.GenerateMessageResponse]]: + r"""Return a callable for the generate message method over gRPC. + + Generates a response from the model given an input + ``MessagePrompt``. + + Returns: + Callable[[~.GenerateMessageRequest], + Awaitable[~.GenerateMessageResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if 'generate_message' not in self._stubs: + self._stubs['generate_message'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.DiscussService/GenerateMessage', + request_serializer=discuss_service.GenerateMessageRequest.serialize, + response_deserializer=discuss_service.GenerateMessageResponse.deserialize, + ) + return self._stubs['generate_message'] + + @property + def count_message_tokens(self) -> Callable[ + [discuss_service.CountMessageTokensRequest], + Awaitable[discuss_service.CountMessageTokensResponse]]: + r"""Return a callable for the count message tokens method over gRPC. + + Runs a model's tokenizer on a string and returns the + token count. + + Returns: + Callable[[~.CountMessageTokensRequest], + Awaitable[~.CountMessageTokensResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'count_message_tokens' not in self._stubs: + self._stubs['count_message_tokens'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.DiscussService/CountMessageTokens', + request_serializer=discuss_service.CountMessageTokensRequest.serialize, + response_deserializer=discuss_service.CountMessageTokensResponse.deserialize, + ) + return self._stubs['count_message_tokens'] + + def close(self): + return self.grpc_channel.close() + + +__all__ = ( + 'DiscussServiceGrpcAsyncIOTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/rest.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/rest.py new file mode 100644 index 000000000000..0585ca398116 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/rest.py @@ -0,0 +1,434 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from google.auth.transport.requests import AuthorizedSession # type: ignore +import json # type: ignore +import grpc # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core import rest_helpers +from google.api_core import rest_streaming +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from requests import __version__ as requests_version +import dataclasses +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + + +from google.ai.generativelanguage_v1beta3.types import discuss_service +from google.longrunning import operations_pb2 # type: ignore + +from .base import DiscussServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class DiscussServiceRestInterceptor: + """Interceptor for DiscussService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the DiscussServiceRestTransport. + + .. code-block:: python + class MyCustomDiscussServiceInterceptor(DiscussServiceRestInterceptor): + def pre_count_message_tokens(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_count_message_tokens(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_generate_message(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_generate_message(self, response): + logging.log(f"Received response: {response}") + return response + + transport = DiscussServiceRestTransport(interceptor=MyCustomDiscussServiceInterceptor()) + client = DiscussServiceClient(transport=transport) + + + """ + def pre_count_message_tokens(self, request: discuss_service.CountMessageTokensRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[discuss_service.CountMessageTokensRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for count_message_tokens + + Override in a subclass to manipulate the request or metadata + before they are sent to the DiscussService server. + """ + return request, metadata + + def post_count_message_tokens(self, response: discuss_service.CountMessageTokensResponse) -> discuss_service.CountMessageTokensResponse: + """Post-rpc interceptor for count_message_tokens + + Override in a subclass to manipulate the response + after it is returned by the DiscussService server but before + it is returned to user code. 
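+
+        An editor-added sketch of one possible override, relying only on
+        the ``token_count`` field documented for this response message:
+
+        .. code-block:: python
+
+            class LoggingInterceptor(DiscussServiceRestInterceptor):
+                def post_count_message_tokens(self, response):
+                    # Observe the count, then hand the response back unchanged.
+                    print(f"counted {response.token_count} tokens")
+                    return response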
+ """ + return response + def pre_generate_message(self, request: discuss_service.GenerateMessageRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[discuss_service.GenerateMessageRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for generate_message + + Override in a subclass to manipulate the request or metadata + before they are sent to the DiscussService server. + """ + return request, metadata + + def post_generate_message(self, response: discuss_service.GenerateMessageResponse) -> discuss_service.GenerateMessageResponse: + """Post-rpc interceptor for generate_message + + Override in a subclass to manipulate the response + after it is returned by the DiscussService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class DiscussServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: DiscussServiceRestInterceptor + + +class DiscussServiceRestTransport(DiscussServiceTransport): + """REST backend transport for DiscussService. + + An API for using Generative Language Models (GLMs) in dialog + applications. + Also known as large language models (LLMs), this API provides + models that are trained for multi-turn dialog. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + + """ + + def __init__(self, *, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[ + ], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = 'https', + interceptor: Optional[DiscussServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. 
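+
+        As an editor-added illustration (the local endpoint below is
+        hypothetical and ambient default credentials are assumed),
+        ``url_scheme`` and ``host`` together can target a plain-HTTP test
+        server:
+
+        .. code-block:: python
+
+            transport = DiscussServiceRestTransport(
+                host="localhost:8080",   # hypothetical local endpoint
+                url_scheme="http",       # plain HTTP for local testing
+            )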
+
+        """
+        # Run the base constructor
+        # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc.
+        # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
+        # credentials object
+        maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host)
+        if maybe_url_match is None:
+            raise ValueError(f"Unexpected hostname structure: {host}")  # pragma: NO COVER
+
+        url_match_items = maybe_url_match.groupdict()
+
+        host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host
+
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            client_info=client_info,
+            always_use_jwt_access=always_use_jwt_access,
+            api_audience=api_audience
+        )
+        self._session = AuthorizedSession(
+            self._credentials, default_host=self.DEFAULT_HOST)
+        if client_cert_source_for_mtls:
+            self._session.configure_mtls_channel(client_cert_source_for_mtls)
+        self._interceptor = interceptor or DiscussServiceRestInterceptor()
+        self._prep_wrapped_messages(client_info)
+
+    class _CountMessageTokens(DiscussServiceRestStub):
+        def __hash__(self):
+            return hash("CountMessageTokens")
+
+        __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] =  {
+        }
+
+        @classmethod
+        def _get_unset_required_fields(cls, message_dict):
+            return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict}
+
+        def __call__(self,
+                request: discuss_service.CountMessageTokensRequest, *,
+                retry: OptionalRetry=gapic_v1.method.DEFAULT,
+                timeout: Optional[float]=None,
+                metadata: Sequence[Tuple[str, str]]=(),
+                ) -> discuss_service.CountMessageTokensResponse:
+            r"""Call the count message tokens method over HTTP.
+
+            Args:
+                request (~.discuss_service.CountMessageTokensRequest):
+                    The request object. Counts the number of tokens in the ``prompt`` sent to a
+                model.
+
+                Models may tokenize text differently, so each model may
+                return a different ``token_count``.
+                retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                    should be retried.
+                timeout (float): The timeout for this request.
+                metadata (Sequence[Tuple[str, str]]): Strings which should be
+                    sent along with the request as metadata.
+
+            Returns:
+                ~.discuss_service.CountMessageTokensResponse:
+                    A response from ``CountMessageTokens``.
+
+                It returns the model's ``token_count`` for the
+                ``prompt``.
+ + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'post', + 'uri': '/v1beta3/{model=models/*}:countMessageTokens', + 'body': '*', + }, + ] + request, metadata = self._interceptor.pre_count_message_tokens(request, metadata) + pb_request = discuss_service.CountMessageTokensRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + including_default_value_fields=False, + use_integers_for_enums=True + ) + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = discuss_service.CountMessageTokensResponse() + pb_resp = discuss_service.CountMessageTokensResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_count_message_tokens(resp) + return resp + + class _GenerateMessage(DiscussServiceRestStub): + def __hash__(self): + return hash("GenerateMessage") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: discuss_service.GenerateMessageRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> discuss_service.GenerateMessageResponse: + r"""Call the generate message method over HTTP. + + Args: + request (~.discuss_service.GenerateMessageRequest): + The request object. Request to generate a message + response from the model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.discuss_service.GenerateMessageResponse: + The response from the model. + + This includes candidate messages and + conversation history in the form of + chronologically-ordered messages. 
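+
+            An editor-added, end-to-end sketch of this REST path through
+            the public client; the model name is a placeholder and ambient
+            default credentials are assumed:
+
+            .. code-block:: python
+
+                from google.ai import generativelanguage_v1beta3
+
+                client = generativelanguage_v1beta3.DiscussServiceClient(transport="rest")
+                response = client.generate_message(
+                    request=generativelanguage_v1beta3.GenerateMessageRequest(
+                        model="models/your-model-id",  # placeholder model name
+                        prompt=generativelanguage_v1beta3.MessagePrompt(
+                            messages=[generativelanguage_v1beta3.Message(content="Hello")],
+                        ),
+                    )
+                )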
+ + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'post', + 'uri': '/v1beta3/{model=models/*}:generateMessage', + 'body': '*', + }, + ] + request, metadata = self._interceptor.pre_generate_message(request, metadata) + pb_request = discuss_service.GenerateMessageRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + including_default_value_fields=False, + use_integers_for_enums=True + ) + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = discuss_service.GenerateMessageResponse() + pb_resp = discuss_service.GenerateMessageResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_generate_message(resp) + return resp + + @property + def count_message_tokens(self) -> Callable[ + [discuss_service.CountMessageTokensRequest], + discuss_service.CountMessageTokensResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CountMessageTokens(self._session, self._host, self._interceptor) # type: ignore + + @property + def generate_message(self) -> Callable[ + [discuss_service.GenerateMessageRequest], + discuss_service.GenerateMessageResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GenerateMessage(self._session, self._host, self._interceptor) # type: ignore + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__=( + 'DiscussServiceRestTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/__init__.py new file mode 100644 index 000000000000..2c368b92d844 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .client import ModelServiceClient +from .async_client import ModelServiceAsyncClient + +__all__ = ( + 'ModelServiceClient', + 'ModelServiceAsyncClient', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/async_client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/async_client.py new file mode 100644 index 000000000000..0759ce5d1845 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/async_client.py @@ -0,0 +1,996 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import functools +import re +from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version + +from google.api_core.client_options import ClientOptions +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.ai.generativelanguage_v1beta3.services.model_service import pagers +from google.ai.generativelanguage_v1beta3.types import model +from google.ai.generativelanguage_v1beta3.types import model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model +from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport +from .client import ModelServiceClient + + +class ModelServiceAsyncClient: + """Provides methods for getting metadata information about + Generative Models. 
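+
+    An editor-added usage sketch; the key file path is hypothetical and the
+    helper used is the classmethod defined just below:
+
+    .. code-block:: python
+
+        client = ModelServiceAsyncClient.from_service_account_file(
+            "service-account.json",  # hypothetical key file path
+        )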
+    """
+
+    _client: ModelServiceClient
+
+    DEFAULT_ENDPOINT = ModelServiceClient.DEFAULT_ENDPOINT
+    DEFAULT_MTLS_ENDPOINT = ModelServiceClient.DEFAULT_MTLS_ENDPOINT
+
+    model_path = staticmethod(ModelServiceClient.model_path)
+    parse_model_path = staticmethod(ModelServiceClient.parse_model_path)
+    tuned_model_path = staticmethod(ModelServiceClient.tuned_model_path)
+    parse_tuned_model_path = staticmethod(ModelServiceClient.parse_tuned_model_path)
+    common_billing_account_path = staticmethod(ModelServiceClient.common_billing_account_path)
+    parse_common_billing_account_path = staticmethod(ModelServiceClient.parse_common_billing_account_path)
+    common_folder_path = staticmethod(ModelServiceClient.common_folder_path)
+    parse_common_folder_path = staticmethod(ModelServiceClient.parse_common_folder_path)
+    common_organization_path = staticmethod(ModelServiceClient.common_organization_path)
+    parse_common_organization_path = staticmethod(ModelServiceClient.parse_common_organization_path)
+    common_project_path = staticmethod(ModelServiceClient.common_project_path)
+    parse_common_project_path = staticmethod(ModelServiceClient.parse_common_project_path)
+    common_location_path = staticmethod(ModelServiceClient.common_location_path)
+    parse_common_location_path = staticmethod(ModelServiceClient.parse_common_location_path)
+
+    @classmethod
+    def from_service_account_info(cls, info: dict, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials
+            info.
+
+        Args:
+            info (dict): The service account private key info.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            ModelServiceAsyncClient: The constructed client.
+        """
+        return ModelServiceClient.from_service_account_info.__func__(ModelServiceAsyncClient, info, *args, **kwargs)  # type: ignore
+
+    @classmethod
+    def from_service_account_file(cls, filename: str, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials
+            file.
+
+        Args:
+            filename (str): The path to the service account private key json
+                file.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            ModelServiceAsyncClient: The constructed client.
+        """
+        return ModelServiceClient.from_service_account_file.__func__(ModelServiceAsyncClient, filename, *args, **kwargs)  # type: ignore
+
+    from_service_account_json = from_service_account_file
+
+    @classmethod
+    def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None):
+        """Return the API endpoint and client cert source for mutual TLS.
+
+        The client cert source is determined in the following order:
+        (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+        client cert source is None.
+        (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+        default client cert source exists, use the default one; otherwise the client cert
+        source is None.
+
+        The API endpoint is determined in the following order:
+        (1) if `client_options.api_endpoint` is provided, use the provided one.
+        (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+        default mTLS endpoint; if the environment variable is "never", use the default API
+        endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+        use the default API endpoint.
+
+        More details can be found at https://google.aip.dev/auth/4114.
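+
+        An editor-added sketch of inspecting the resolution, assuming the
+        environment variables above are unset:
+
+        .. code-block:: python
+
+            from google.api_core.client_options import ClientOptions
+
+            options = ClientOptions(api_endpoint="generativelanguage.googleapis.com")
+            endpoint, cert_source = ModelServiceAsyncClient.get_mtls_endpoint_and_cert_source(options)
+            # With an explicit api_endpoint it is returned as-is; the cert
+            # source is None unless client certificates are enabled.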
+ + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return ModelServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> ModelServiceTransport: + """Returns the transport used by the client instance. + + Returns: + ModelServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(ModelServiceClient).get_transport_class, type(ModelServiceClient)) + + def __init__(self, *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Union[str, ModelServiceTransport] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the model service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.ModelServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client = ModelServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + + ) + + async def get_model(self, + request: Optional[Union[model_service.GetModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: + r"""Gets information about a specific Model. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + async def sample_get_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.GetModelRequest( + name="name_value", + ) + + # Make the request + response = await client.get_model(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta3.types.GetModelRequest, dict]]): + The request object. Request for getting information about + a specific Model. + name (:class:`str`): + Required. The resource name of the model. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.Model: + Information about a Generative + Language Model. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = model_service.GetModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_models(self, + request: Optional[Union[model_service.ListModelsRequest, dict]] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsAsyncPager: + r"""Lists models available through the API. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+            # - It may require specifying regional endpoints when creating the service
+            #   client as shown in:
+            #   https://googleapis.dev/python/google-api-core/latest/client_options.html
+            from google.ai import generativelanguage_v1beta3
+
+            async def sample_list_models():
+                # Create a client
+                client = generativelanguage_v1beta3.ModelServiceAsyncClient()
+
+                # Initialize request argument(s)
+                request = generativelanguage_v1beta3.ListModelsRequest(
+                )
+
+                # Make the request
+                page_result = await client.list_models(request=request)
+
+                # Handle the response
+                async for response in page_result:
+                    print(response)
+
+        Args:
+            request (Optional[Union[google.ai.generativelanguage_v1beta3.types.ListModelsRequest, dict]]):
+                The request object. Request for listing all Models.
+            page_size (:class:`int`):
+                The maximum number of ``Models`` to return (per page).
+
+                The service may return fewer models. If unspecified, at
+                most 50 models will be returned per page. This method
+                returns at most 1000 models per page, even if you pass a
+                larger page_size.
+
+                This corresponds to the ``page_size`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            page_token (:class:`str`):
+                A page token, received from a previous ``ListModels``
+                call.
+
+                Provide the ``page_token`` returned by one request as an
+                argument to the next request to retrieve the next page.
+
+                When paginating, all other parameters provided to
+                ``ListModels`` must match the call that provided the
+                page token.
+
+                This corresponds to the ``page_token`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.ai.generativelanguage_v1beta3.services.model_service.pagers.ListModelsAsyncPager:
+                Response from ListModels containing a paginated list of
+                Models.
+
+                Iterating over this object will yield results and
+                resolve additional pages automatically.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Quick check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([page_size, page_token])
+        if request is not None and has_flattened_params:
+            raise ValueError("If the `request` argument is set, then none of "
+                             "the individual field arguments should be set.")
+
+        request = model_service.ListModelsRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+        if page_size is not None:
+            request.page_size = page_size
+        if page_token is not None:
+            request.page_token = page_token
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.list_models,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__aiter__` convenience method.
+        response = pagers.ListModelsAsyncPager(
+            method=rpc,
+            request=request,
+            response=response,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+ return response + + async def get_tuned_model(self, + request: Optional[Union[model_service.GetTunedModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tuned_model.TunedModel: + r"""Gets information about a specific TunedModel. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + async def sample_get_tuned_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.GetTunedModelRequest( + name="name_value", + ) + + # Make the request + response = await client.get_tuned_model(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta3.types.GetTunedModelRequest, dict]]): + The request object. Request for getting information about + a specific Model. + name (:class:`str`): + Required. The resource name of the model. + + Format: ``tunedModels/my-model-id`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.TunedModel: + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = model_service.GetTunedModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_tuned_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+        return response
+
+    async def list_tuned_models(self,
+            request: Optional[Union[model_service.ListTunedModelsRequest, dict]] = None,
+            *,
+            page_size: Optional[int] = None,
+            page_token: Optional[str] = None,
+            retry: OptionalRetry = gapic_v1.method.DEFAULT,
+            timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> pagers.ListTunedModelsAsyncPager:
+        r"""Lists tuned models owned by the user.
+
+        .. code-block:: python
+
+            # This snippet has been automatically generated and should be regarded as a
+            # code template only.
+            # It will require modifications to work:
+            # - It may require correct/in-range values for request initialization.
+            # - It may require specifying regional endpoints when creating the service
+            #   client as shown in:
+            #   https://googleapis.dev/python/google-api-core/latest/client_options.html
+            from google.ai import generativelanguage_v1beta3
+
+            async def sample_list_tuned_models():
+                # Create a client
+                client = generativelanguage_v1beta3.ModelServiceAsyncClient()
+
+                # Initialize request argument(s)
+                request = generativelanguage_v1beta3.ListTunedModelsRequest(
+                )
+
+                # Make the request
+                page_result = await client.list_tuned_models(request=request)
+
+                # Handle the response
+                async for response in page_result:
+                    print(response)
+
+        Args:
+            request (Optional[Union[google.ai.generativelanguage_v1beta3.types.ListTunedModelsRequest, dict]]):
+                The request object. Request for listing TunedModels.
+            page_size (:class:`int`):
+                Optional. The maximum number of ``TunedModels`` to
+                return (per page). The service may return fewer tuned
+                models.
+
+                If unspecified, at most 10 tuned models will be
+                returned. This method returns at most 1000 models per
+                page, even if you pass a larger page_size.
+
+                This corresponds to the ``page_size`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            page_token (:class:`str`):
+                Optional. A page token, received from a previous
+                ``ListTunedModels`` call.
+
+                Provide the ``page_token`` returned by one request as an
+                argument to the next request to retrieve the next page.
+
+                When paginating, all other parameters provided to
+                ``ListTunedModels`` must match the call that provided
+                the page token.
+
+                This corresponds to the ``page_token`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.ai.generativelanguage_v1beta3.services.model_service.pagers.ListTunedModelsAsyncPager:
+                Response from ListTunedModels containing a paginated
+                list of Models.
+
+                Iterating over this object will yield results and
+                resolve additional pages automatically.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Quick check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([page_size, page_token])
+        if request is not None and has_flattened_params:
+            raise ValueError("If the `request` argument is set, then none of "
+                             "the individual field arguments should be set.")
+
+        request = model_service.ListTunedModelsRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+        if page_size is not None:
+            request.page_size = page_size
+        if page_token is not None:
+            request.page_token = page_token
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.list_tuned_models,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Send the request.
+        response = await rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__aiter__` convenience method.
+        response = pagers.ListTunedModelsAsyncPager(
+            method=rpc,
+            request=request,
+            response=response,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def create_tuned_model(self,
+            request: Optional[Union[model_service.CreateTunedModelRequest, dict]] = None,
+            *,
+            tuned_model: Optional[gag_tuned_model.TunedModel] = None,
+            tuned_model_id: Optional[str] = None,
+            retry: OptionalRetry = gapic_v1.method.DEFAULT,
+            timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> operation_async.AsyncOperation:
+        r"""Creates a tuned model. Intermediate tuning progress (if any) is
+        accessed through the [google.longrunning.Operations] service.
+
+        Status and results can be accessed through the Operations
+        service. Example: GET
+        /v1/tunedModels/az2mb0bpw6i/operations/000-111-222
+
+        .. code-block:: python
+
+            # This snippet has been automatically generated and should be regarded as a
+            # code template only.
+            # It will require modifications to work:
+            # - It may require correct/in-range values for request initialization.
+            # - It may require specifying regional endpoints when creating the service
+            #   client as shown in:
+            #   https://googleapis.dev/python/google-api-core/latest/client_options.html
+            from google.ai import generativelanguage_v1beta3
+
+            async def sample_create_tuned_model():
+                # Create a client
+                client = generativelanguage_v1beta3.ModelServiceAsyncClient()
+
+                # Initialize request argument(s)
+                tuned_model = generativelanguage_v1beta3.TunedModel()
+                tuned_model.tuning_task.training_data.examples.examples.text_input = "text_input_value"
+                tuned_model.tuning_task.training_data.examples.examples.output = "output_value"
+
+                request = generativelanguage_v1beta3.CreateTunedModelRequest(
+                    tuned_model=tuned_model,
+                )
+
+                # Make the request
+                operation = client.create_tuned_model(request=request)
+
+                print("Waiting for operation to complete...")
+
+                response = (await operation).result()
+
+                # Handle the response
+                print(response)
+
+        Args:
+            request (Optional[Union[google.ai.generativelanguage_v1beta3.types.CreateTunedModelRequest, dict]]):
+                The request object. Request to create a TunedModel.
+            tuned_model (:class:`google.ai.generativelanguage_v1beta3.types.TunedModel`):
+                Required. The tuned model to create.
+                This corresponds to the ``tuned_model`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            tuned_model_id (:class:`str`):
+                Optional. The unique id for the tuned model if
+                specified. This value should be up to 40 characters, the
+                first character must be a letter, the last could be a
+                letter or a number. The id must match the regular
+                expression: ``[a-z]([a-z0-9-]{0,38}[a-z0-9])?``.
+
+                This corresponds to the ``tuned_model_id`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.ai.generativelanguage_v1beta3.types.TunedModel` + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tuned_model, tuned_model_id]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = model_service.CreateTunedModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if tuned_model is not None: + request.tuned_model = tuned_model + if tuned_model_id is not None: + request.tuned_model_id = tuned_model_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_tuned_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gag_tuned_model.TunedModel, + metadata_type=model_service.CreateTunedModelMetadata, + ) + + # Done; return the response. + return response + + async def update_tuned_model(self, + request: Optional[Union[model_service.UpdateTunedModelRequest, dict]] = None, + *, + tuned_model: Optional[gag_tuned_model.TunedModel] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_tuned_model.TunedModel: + r"""Updates a tuned model. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + async def sample_update_tuned_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceAsyncClient() + + # Initialize request argument(s) + tuned_model = generativelanguage_v1beta3.TunedModel() + tuned_model.tuning_task.training_data.examples.examples.text_input = "text_input_value" + tuned_model.tuning_task.training_data.examples.examples.output = "output_value" + + request = generativelanguage_v1beta3.UpdateTunedModelRequest( + tuned_model=tuned_model, + ) + + # Make the request + response = await client.update_tuned_model(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta3.types.UpdateTunedModelRequest, dict]]): + The request object. Request to update a TunedModel. + tuned_model (:class:`google.ai.generativelanguage_v1beta3.types.TunedModel`): + Required. The tuned model to update. + This corresponds to the ``tuned_model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The list of fields to + update. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.TunedModel: + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tuned_model, update_mask]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = model_service.UpdateTunedModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if tuned_model is not None: + request.tuned_model = tuned_model + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_tuned_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("tuned_model.name", request.tuned_model.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + async def delete_tuned_model(self, + request: Optional[Union[model_service.DeleteTunedModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes a tuned model. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + async def sample_delete_tuned_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.DeleteTunedModelRequest( + name="name_value", + ) + + # Make the request + await client.delete_tuned_model(request=request) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta3.types.DeleteTunedModelRequest, dict]]): + The request object. Request to delete a TunedModel. + name (:class:`str`): + Required. The resource name of the model. Format: + ``tunedModels/my-model-id`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = model_service.DeleteTunedModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_tuned_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Send the request. 
+ await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def __aenter__(self) -> "ModelServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +__all__ = ( + "ModelServiceAsyncClient", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/client.py new file mode 100644 index 000000000000..d64fa37b6f6f --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/client.py @@ -0,0 +1,1213 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import os +import re +from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.ai.generativelanguage_v1beta3.services.model_service import pagers +from google.ai.generativelanguage_v1beta3.types import model +from google.ai.generativelanguage_v1beta3.types import model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model +from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import ModelServiceGrpcTransport +from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport +from .transports.rest import ModelServiceRestTransport + + +class ModelServiceClientMeta(type): + """Metaclass for the ModelService client. + + This provides class-level methods for building and retrieving + support objects (e.g. 
transport) without polluting the client instance
+    objects.
+    """
+    _transport_registry = OrderedDict()  # type: Dict[str, Type[ModelServiceTransport]]
+    _transport_registry["grpc"] = ModelServiceGrpcTransport
+    _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport
+    _transport_registry["rest"] = ModelServiceRestTransport
+
+    def get_transport_class(cls,
+            label: Optional[str] = None,
+        ) -> Type[ModelServiceTransport]:
+        """Returns an appropriate transport class.
+
+        Args:
+            label: The name of the desired transport. If none is
+                provided, then the first transport in the registry is used.
+
+        Returns:
+            The transport class to use.
+        """
+        # If a specific transport is requested, return that one.
+        if label:
+            return cls._transport_registry[label]
+
+        # No transport is requested; return the default (that is, the first one
+        # in the dictionary).
+        return next(iter(cls._transport_registry.values()))
+
+
+class ModelServiceClient(metaclass=ModelServiceClientMeta):
+    """Provides methods for getting metadata information about
+    Generative Models.
+    """
+
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Converts api endpoint to mTLS endpoint.
+
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = "generativelanguage.googleapis.com"
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
+    )
+
+    @classmethod
+    def from_service_account_info(cls, info: dict, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials
+            info.
+
+        Args:
+            info (dict): The service account private key info.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            ModelServiceClient: The constructed client.
+        """
+        credentials = service_account.Credentials.from_service_account_info(info)
+        kwargs["credentials"] = credentials
+        return cls(*args, **kwargs)
+
+    @classmethod
+    def from_service_account_file(cls, filename: str, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials
+            file.
+
+        Args:
+            filename (str): The path to the service account private key json
+                file.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            ModelServiceClient: The constructed client.
+        """
+        credentials = service_account.Credentials.from_service_account_file(
+            filename)
+        kwargs["credentials"] = credentials
+        return cls(*args, **kwargs)
+
+    from_service_account_json = from_service_account_file
+
+    @property
+    def transport(self) -> ModelServiceTransport:
+        """Returns the transport used by the client instance.
+
+        Returns:
+            ModelServiceTransport: The transport used by the client
+                instance.
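+
+        Note: if this transport instance is shared with other clients,
+        closing it (for example by using this client as a context manager)
+        closes it for those clients as well; see the warning on ``__exit__``.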
+        """
+        return self._transport
+
+    @staticmethod
+    def model_path(model: str,) -> str:
+        """Returns a fully-qualified model string."""
+        return "models/{model}".format(model=model, )
+
+    @staticmethod
+    def parse_model_path(path: str) -> Dict[str,str]:
+        """Parses a model path into its component segments."""
+        m = re.match(r"^models/(?P<model>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def tuned_model_path(tuned_model: str,) -> str:
+        """Returns a fully-qualified tuned_model string."""
+        return "tunedModels/{tuned_model}".format(tuned_model=tuned_model, )
+
+    @staticmethod
+    def parse_tuned_model_path(path: str) -> Dict[str,str]:
+        """Parses a tuned_model path into its component segments."""
+        m = re.match(r"^tunedModels/(?P<tuned_model>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_billing_account_path(billing_account: str, ) -> str:
+        """Returns a fully-qualified billing_account string."""
+        return "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+
+    @staticmethod
+    def parse_common_billing_account_path(path: str) -> Dict[str,str]:
+        """Parse a billing_account path into its component segments."""
+        m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_folder_path(folder: str, ) -> str:
+        """Returns a fully-qualified folder string."""
+        return "folders/{folder}".format(folder=folder, )
+
+    @staticmethod
+    def parse_common_folder_path(path: str) -> Dict[str,str]:
+        """Parse a folder path into its component segments."""
+        m = re.match(r"^folders/(?P<folder>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_organization_path(organization: str, ) -> str:
+        """Returns a fully-qualified organization string."""
+        return "organizations/{organization}".format(organization=organization, )
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str,str]:
+        """Parse an organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(project: str, ) -> str:
+        """Returns a fully-qualified project string."""
+        return "projects/{project}".format(project=project, )
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str,str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(project: str, location: str, ) -> str:
+        """Returns a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(project=project, location=location, )
+
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str,str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @classmethod
+    def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None):
+        """Return the API endpoint and client cert source for mutual TLS.
+
+        The client cert source is determined in the following order:
+        (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+        client cert source is None.
+        (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+        default client cert source exists, use the default one; otherwise the client cert
+        source is None.
+
+        The API endpoint is determined in the following order:
+        (1) if `client_options.api_endpoint` is provided, use the provided one.
+        (2) if the `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+        default mTLS endpoint; if the environment variable is "never", use the default API
+        endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+        use the default API endpoint.
+
+        More details can be found at https://google.aip.dev/auth/4114.
+
+        Args:
+            client_options (google.api_core.client_options.ClientOptions): Custom options for the
+                client. Only the `api_endpoint` and `client_cert_source` properties may be used
+                in this method.
+
+        Returns:
+            Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+                client cert source to use.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+        """
+        if client_options is None:
+            client_options = client_options_lib.ClientOptions()
+        use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
+        use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
+        if use_client_cert not in ("true", "false"):
+            raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`")
+        if use_mtls_endpoint not in ("auto", "never", "always"):
+            raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`")
+
+        # Figure out the client cert source to use.
+        client_cert_source = None
+        if use_client_cert == "true":
+            if client_options.client_cert_source:
+                client_cert_source = client_options.client_cert_source
+            elif mtls.has_default_client_cert_source():
+                client_cert_source = mtls.default_client_cert_source()
+
+        # Figure out which api endpoint to use.
+        if client_options.api_endpoint is not None:
+            api_endpoint = client_options.api_endpoint
+        elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source):
+            api_endpoint = cls.DEFAULT_MTLS_ENDPOINT
+        else:
+            api_endpoint = cls.DEFAULT_ENDPOINT
+
+        return api_endpoint, client_cert_source
+
+    def __init__(self, *,
+            credentials: Optional[ga_credentials.Credentials] = None,
+            transport: Optional[Union[str, ModelServiceTransport]] = None,
+            client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            ) -> None:
+        """Instantiates the model service client.
+
+        Args:
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            transport (Union[str, ModelServiceTransport]): The
+                transport to use. If set to None, a transport is chosen
+                automatically.
+            client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the
+                client. It won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto-switch to the
+                default mTLS endpoint if client certificate is present; this is
+                the default value).
However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + client_options = cast(client_options_lib.ClientOptions, client_options) + + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) + + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError("client_options.api_key and credentials are mutually exclusive") + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, ModelServiceTransport): + # transport is a ModelServiceTransport instance. + if credentials or client_options.credentials_file or api_key_value: + raise ValueError("When providing a transport instance, " + "provide its credentials directly.") + if client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = transport + else: + import google.auth._default # type: ignore + + if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): + credentials = google.auth._default.get_api_key_credentials(api_key_value) + + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=client_options.api_audience, + ) + + def get_model(self, + request: Optional[Union[model_service.GetModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: + r"""Gets information about a specific Model. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_get_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.GetModelRequest( + name="name_value", + ) + + # Make the request + response = client.get_model(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.GetModelRequest, dict]): + The request object. Request for getting information about + a specific Model. + name (str): + Required. The resource name of the model. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.Model: + Information about a Generative + Language Model. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.GetModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.GetModelRequest): + request = model_service.GetModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_models(self, + request: Optional[Union[model_service.ListModelsRequest, dict]] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsPager: + r"""Lists models available through the API. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_list_models(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.ListModelsRequest( + ) + + # Make the request + page_result = client.list_models(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.ListModelsRequest, dict]): + The request object. Request for listing all Models. + page_size (int): + The maximum number of ``Models`` to return (per page). + + The service may return fewer models. If unspecified, at + most 50 models will be returned per page. This method + returns at most 1000 models per page, even if you pass a + larger page_size. + + This corresponds to the ``page_size`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + page_token (str): + A page token, received from a previous ``ListModels`` + call. + + Provide the ``page_token`` returned by one request as an + argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListModels`` must match the call that provided the + page token. + + This corresponds to the ``page_token`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.services.model_service.pagers.ListModelsPager: + Response from ListModel containing a paginated list of + Models. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([page_size, page_token]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ListModelsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ListModelsRequest): + request = model_service.ListModelsRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if page_size is not None: + request.page_size = page_size + if page_token is not None: + request.page_token = page_token + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_models] + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. 
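+        # A minimal, hypothetical sketch of consuming the returned pager
+        # (not part of the generated code; ``client`` and field names as
+        # used in the samples above):
+        #
+        #     for model in client.list_models(page_size=50):
+        #         print(model.name)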
+ response = pagers.ListModelsPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_tuned_model(self, + request: Optional[Union[model_service.GetTunedModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tuned_model.TunedModel: + r"""Gets information about a specific TunedModel. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_get_tuned_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.GetTunedModelRequest( + name="name_value", + ) + + # Make the request + response = client.get_tuned_model(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.GetTunedModelRequest, dict]): + The request object. Request for getting information about + a specific Model. + name (str): + Required. The resource name of the model. + + Format: ``tunedModels/my-model-id`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.TunedModel: + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.GetTunedModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.GetTunedModelRequest): + request = model_service.GetTunedModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_tuned_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Send the request. 
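+        # (The ``name`` routing header attached above lets the backend route
+        # the call to the right tuned model resource.)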
+ response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_tuned_models(self, + request: Optional[Union[model_service.ListTunedModelsRequest, dict]] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTunedModelsPager: + r"""Lists tuned models owned by the user. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_list_tuned_models(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.ListTunedModelsRequest( + ) + + # Make the request + page_result = client.list_tuned_models(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.ListTunedModelsRequest, dict]): + The request object. Request for listing TunedModels. + page_size (int): + Optional. The maximum number of ``TunedModels`` to + return (per page). The service may return fewer tuned + models. + + If unspecified, at most 10 tuned models will be + returned. This method returns at most 1000 models per + page, even if you pass a larger page_size. + + This corresponds to the ``page_size`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + page_token (str): + Optional. A page token, received from a previous + ``ListTunedModels`` call. + + Provide the ``page_token`` returned by one request as an + argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListTunedModels`` must match the call that provided + the page token. + + This corresponds to the ``page_token`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.services.model_service.pagers.ListTunedModelsPager: + Response from ListTunedModels containing a paginated + list of Models. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([page_size, page_token]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ListTunedModelsRequest. 
+        # There's no risk of modifying the input as we've already verified
+        # there are no flattened fields.
+        if not isinstance(request, model_service.ListTunedModelsRequest):
+            request = model_service.ListTunedModelsRequest(request)
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+        if page_size is not None:
+            request.page_size = page_size
+        if page_token is not None:
+            request.page_token = page_token
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.list_tuned_models]
+
+        # Send the request.
+        response = rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__iter__` convenience method.
+        response = pagers.ListTunedModelsPager(
+            method=rpc,
+            request=request,
+            response=response,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    def create_tuned_model(self,
+            request: Optional[Union[model_service.CreateTunedModelRequest, dict]] = None,
+            *,
+            tuned_model: Optional[gag_tuned_model.TunedModel] = None,
+            tuned_model_id: Optional[str] = None,
+            retry: OptionalRetry = gapic_v1.method.DEFAULT,
+            timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> operation.Operation:
+        r"""Creates a tuned model. Intermediate tuning progress (if any) is
+        accessed through the [google.longrunning.Operations] service.
+
+        Status and results can be accessed through the Operations
+        service. Example: GET
+        /v1/tunedModels/az2mb0bpw6i/operations/000-111-222
+
+        .. code-block:: python
+
+            # This snippet has been automatically generated and should be regarded as a
+            # code template only.
+            # It will require modifications to work:
+            # - It may require correct/in-range values for request initialization.
+            # - It may require specifying regional endpoints when creating the service
+            #   client as shown in:
+            #   https://googleapis.dev/python/google-api-core/latest/client_options.html
+            from google.ai import generativelanguage_v1beta3
+
+            def sample_create_tuned_model():
+                # Create a client
+                client = generativelanguage_v1beta3.ModelServiceClient()
+
+                # Initialize request argument(s)
+                tuned_model = generativelanguage_v1beta3.TunedModel()
+                tuned_model.tuning_task.training_data.examples.examples.text_input = "text_input_value"
+                tuned_model.tuning_task.training_data.examples.examples.output = "output_value"
+
+                request = generativelanguage_v1beta3.CreateTunedModelRequest(
+                    tuned_model=tuned_model,
+                )
+
+                # Make the request
+                operation = client.create_tuned_model(request=request)
+
+                print("Waiting for operation to complete...")
+
+                response = operation.result()
+
+                # Handle the response
+                print(response)
+
+        Args:
+            request (Union[google.ai.generativelanguage_v1beta3.types.CreateTunedModelRequest, dict]):
+                The request object. Request to create a TunedModel.
+            tuned_model (google.ai.generativelanguage_v1beta3.types.TunedModel):
+                Required. The tuned model to create.
+                This corresponds to the ``tuned_model`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            tuned_model_id (str):
+                Optional. The unique id for the tuned model if
+                specified. This value should be up to 40 characters, the
+                first character must be a letter, the last could be a
+                letter or a number. The id must match the regular
+                expression: ``[a-z]([a-z0-9-]{0,38}[a-z0-9])?``.
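+                For example, a valid (hypothetical) id would be
+                ``my-tuned-model-01``.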
+ + This corresponds to the ``tuned_model_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.ai.generativelanguage_v1beta3.types.TunedModel` + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tuned_model, tuned_model_id]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.CreateTunedModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.CreateTunedModelRequest): + request = model_service.CreateTunedModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if tuned_model is not None: + request.tuned_model = tuned_model + if tuned_model_id is not None: + request.tuned_model_id = tuned_model_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_tuned_model] + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation.from_gapic( + response, + self._transport.operations_client, + gag_tuned_model.TunedModel, + metadata_type=model_service.CreateTunedModelMetadata, + ) + + # Done; return the response. + return response + + def update_tuned_model(self, + request: Optional[Union[model_service.UpdateTunedModelRequest, dict]] = None, + *, + tuned_model: Optional[gag_tuned_model.TunedModel] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_tuned_model.TunedModel: + r"""Updates a tuned model. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_update_tuned_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceClient() + + # Initialize request argument(s) + tuned_model = generativelanguage_v1beta3.TunedModel() + tuned_model.tuning_task.training_data.examples.examples.text_input = "text_input_value" + tuned_model.tuning_task.training_data.examples.examples.output = "output_value" + + request = generativelanguage_v1beta3.UpdateTunedModelRequest( + tuned_model=tuned_model, + ) + + # Make the request + response = client.update_tuned_model(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.UpdateTunedModelRequest, dict]): + The request object. Request to update a TunedModel. + tuned_model (google.ai.generativelanguage_v1beta3.types.TunedModel): + Required. The tuned model to update. + This corresponds to the ``tuned_model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The list of fields to + update. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.TunedModel: + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tuned_model, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.UpdateTunedModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.UpdateTunedModelRequest): + request = model_service.UpdateTunedModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if tuned_model is not None: + request.tuned_model = tuned_model + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_tuned_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("tuned_model.name", request.tuned_model.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + def delete_tuned_model(self, + request: Optional[Union[model_service.DeleteTunedModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes a tuned model. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_delete_tuned_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.DeleteTunedModelRequest( + name="name_value", + ) + + # Make the request + client.delete_tuned_model(request=request) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.DeleteTunedModelRequest, dict]): + The request object. Request to delete a TunedModel. + name (str): + Required. The resource name of the model. Format: + ``tunedModels/my-model-id`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.DeleteTunedModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.DeleteTunedModelRequest): + request = model_service.DeleteTunedModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_tuned_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def __enter__(self) -> "ModelServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! 
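+
+            A (hypothetical) safer pattern is to scope the transport to a
+            single client with the context manager::
+
+                with ModelServiceClient() as client:
+                    client.get_model(request={"name": "models/some-model"})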
+ """ + self.transport.close() + + + + + + + + + + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +__all__ = ( + "ModelServiceClient", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/pagers.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/pagers.py new file mode 100644 index 000000000000..ede1634e0b87 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/pagers.py @@ -0,0 +1,262 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Any, AsyncIterator, Awaitable, Callable, Sequence, Tuple, Optional, Iterator + +from google.ai.generativelanguage_v1beta3.types import model +from google.ai.generativelanguage_v1beta3.types import model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model + + +class ListModelsPager: + """A pager for iterating through ``list_models`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1beta3.types.ListModelsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``models`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListModels`` requests and continue to iterate + through the ``models`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1beta3.types.ListModelsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., model_service.ListModelsResponse], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1beta3.types.ListModelsRequest): + The initial request object. + response (google.ai.generativelanguage_v1beta3.types.ListModelsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = model_service.ListModelsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[model_service.ListModelsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterator[model.Model]: + for page in self.pages: + yield from page.models + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListModelsAsyncPager: + """A pager for iterating through ``list_models`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1beta3.types.ListModelsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``models`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListModels`` requests and continue to iterate + through the ``models`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1beta3.types.ListModelsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[model_service.ListModelsResponse]], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1beta3.types.ListModelsRequest): + The initial request object. + response (google.ai.generativelanguage_v1beta3.types.ListModelsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = model_service.ListModelsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[model_service.ListModelsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + def __aiter__(self) -> AsyncIterator[model.Model]: + async def async_generator(): + async for page in self.pages: + for response in page.models: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListTunedModelsPager: + """A pager for iterating through ``list_tuned_models`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1beta3.types.ListTunedModelsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``tuned_models`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListTunedModels`` requests and continue to iterate + through the ``tuned_models`` field on the + corresponding responses. 
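+
+    A minimal (hypothetical) usage sketch::
+
+        for tuned_model in client.list_tuned_models():
+            print(tuned_model.name)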
+ + All the usual :class:`google.ai.generativelanguage_v1beta3.types.ListTunedModelsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., model_service.ListTunedModelsResponse], + request: model_service.ListTunedModelsRequest, + response: model_service.ListTunedModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1beta3.types.ListTunedModelsRequest): + The initial request object. + response (google.ai.generativelanguage_v1beta3.types.ListTunedModelsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = model_service.ListTunedModelsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[model_service.ListTunedModelsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterator[tuned_model.TunedModel]: + for page in self.pages: + yield from page.tuned_models + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListTunedModelsAsyncPager: + """A pager for iterating through ``list_tuned_models`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1beta3.types.ListTunedModelsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``tuned_models`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListTunedModels`` requests and continue to iterate + through the ``tuned_models`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1beta3.types.ListTunedModelsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[model_service.ListTunedModelsResponse]], + request: model_service.ListTunedModelsRequest, + response: model_service.ListTunedModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1beta3.types.ListTunedModelsRequest): + The initial request object. + response (google.ai.generativelanguage_v1beta3.types.ListTunedModelsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+        """
+        self._method = method
+        self._request = model_service.ListTunedModelsRequest(request)
+        self._response = response
+        self._metadata = metadata
+
+    def __getattr__(self, name: str) -> Any:
+        return getattr(self._response, name)
+
+    @property
+    async def pages(self) -> AsyncIterator[model_service.ListTunedModelsResponse]:
+        yield self._response
+        while self._response.next_page_token:
+            self._request.page_token = self._response.next_page_token
+            self._response = await self._method(self._request, metadata=self._metadata)
+            yield self._response
+
+    def __aiter__(self) -> AsyncIterator[tuned_model.TunedModel]:
+        async def async_generator():
+            async for page in self.pages:
+                for response in page.tuned_models:
+                    yield response
+
+        return async_generator()
+
+    def __repr__(self) -> str:
+        return '{0}<{1!r}>'.format(self.__class__.__name__, self._response)
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/__init__.py
new file mode 100644
index 000000000000..c51cadf4ba09
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/__init__.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from collections import OrderedDict
+from typing import Dict, Type
+
+from .base import ModelServiceTransport
+from .grpc import ModelServiceGrpcTransport
+from .grpc_asyncio import ModelServiceGrpcAsyncIOTransport
+from .rest import ModelServiceRestTransport
+from .rest import ModelServiceRestInterceptor
+
+
+# Compile a registry of transports.
+_transport_registry = OrderedDict()  # type: Dict[str, Type[ModelServiceTransport]]
+_transport_registry['grpc'] = ModelServiceGrpcTransport
+_transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport
+_transport_registry['rest'] = ModelServiceRestTransport
+
+__all__ = (
+    'ModelServiceTransport',
+    'ModelServiceGrpcTransport',
+    'ModelServiceGrpcAsyncIOTransport',
+    'ModelServiceRestTransport',
+    'ModelServiceRestInterceptor',
+)
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/base.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/base.py
new file mode 100644
index 000000000000..dcc5074c3ae7
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/base.py
@@ -0,0 +1,242 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
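Before moving into the transports: a hedged usage sketch for the pagers in `pagers.py` above. It is not part of the generated patch; it assumes Application Default Credentials are configured, and the page sizes are illustrative. The client's `list_models`/`list_tuned_models` methods return a pager rather than a raw response:

    # Minimal sketch: iterating the generated pagers. The pager lazily issues
    # follow-up ListModels calls whenever next_page_token is set.
    from google.ai.generativelanguage_v1beta3 import ModelServiceClient

    client = ModelServiceClient()  # assumes default credentials

    # Item-by-item iteration; extra pages are fetched transparently.
    for m in client.list_models(request={"page_size": 10}):
        print(m.name)

    # Page-by-page iteration via the `pages` property.
    pager = client.list_models(request={"page_size": 10})
    for page in pager.pages:
        print(len(page.models))

    # The async pagers mirror this with `async for`:
    #     pager = await async_client.list_models(request={"page_size": 10})
    #     async for m in pager:
    #         print(m.name)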
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version + +import google.auth # type: ignore +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import operations_v1 +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1beta3.types import model +from google.ai.generativelanguage_v1beta3.types import model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model +from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +class ModelServiceTransport(abc.ABC): + """Abstract transport class for ModelService.""" + + AUTH_SCOPES = ( + ) + + DEFAULT_HOST: str = 'generativelanguage.googleapis.com' + def __init__( + self, *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + + # If no credentials are provided, then determine the appropriate + # defaults. 
+ if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, + **scopes_kwargs, + quota_project_id=quota_project_id + ) + elif credentials is None: + credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience(api_audience if api_audience else host) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ':' not in host: + host += ':443' + self._host = host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.get_model: gapic_v1.method.wrap_method( + self.get_model, + default_timeout=None, + client_info=client_info, + ), + self.list_models: gapic_v1.method.wrap_method( + self.list_models, + default_timeout=None, + client_info=client_info, + ), + self.get_tuned_model: gapic_v1.method.wrap_method( + self.get_tuned_model, + default_timeout=None, + client_info=client_info, + ), + self.list_tuned_models: gapic_v1.method.wrap_method( + self.list_tuned_models, + default_timeout=None, + client_info=client_info, + ), + self.create_tuned_model: gapic_v1.method.wrap_method( + self.create_tuned_model, + default_timeout=None, + client_info=client_info, + ), + self.update_tuned_model: gapic_v1.method.wrap_method( + self.update_tuned_model, + default_timeout=None, + client_info=client_info, + ), + self.delete_tuned_model: gapic_v1.method.wrap_method( + self.delete_tuned_model, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! 
+ """ + raise NotImplementedError() + + @property + def operations_client(self): + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def get_model(self) -> Callable[ + [model_service.GetModelRequest], + Union[ + model.Model, + Awaitable[model.Model] + ]]: + raise NotImplementedError() + + @property + def list_models(self) -> Callable[ + [model_service.ListModelsRequest], + Union[ + model_service.ListModelsResponse, + Awaitable[model_service.ListModelsResponse] + ]]: + raise NotImplementedError() + + @property + def get_tuned_model(self) -> Callable[ + [model_service.GetTunedModelRequest], + Union[ + tuned_model.TunedModel, + Awaitable[tuned_model.TunedModel] + ]]: + raise NotImplementedError() + + @property + def list_tuned_models(self) -> Callable[ + [model_service.ListTunedModelsRequest], + Union[ + model_service.ListTunedModelsResponse, + Awaitable[model_service.ListTunedModelsResponse] + ]]: + raise NotImplementedError() + + @property + def create_tuned_model(self) -> Callable[ + [model_service.CreateTunedModelRequest], + Union[ + operations_pb2.Operation, + Awaitable[operations_pb2.Operation] + ]]: + raise NotImplementedError() + + @property + def update_tuned_model(self) -> Callable[ + [model_service.UpdateTunedModelRequest], + Union[ + gag_tuned_model.TunedModel, + Awaitable[gag_tuned_model.TunedModel] + ]]: + raise NotImplementedError() + + @property + def delete_tuned_model(self) -> Callable[ + [model_service.DeleteTunedModelRequest], + Union[ + empty_pb2.Empty, + Awaitable[empty_pb2.Empty] + ]]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ( + 'ModelServiceTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc.py new file mode 100644 index 000000000000..c6a0e54df2da --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc.py @@ -0,0 +1,449 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
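Before continuing with the gRPC transport: the credential handling in the `ModelServiceTransport` constructor above follows a fixed precedence. The sketch below restates it as a standalone function for readability; it is illustrative only and omits the `default_scopes` plumbing the real transport passes through:

    # Precedence implemented by ModelServiceTransport.__init__:
    # explicit credentials XOR credentials_file; otherwise fall back to ADC.
    import google.auth
    from google.api_core import exceptions as core_exceptions

    def resolve_credentials(credentials=None, credentials_file=None,
                            scopes=None, quota_project_id=None):
        if credentials and credentials_file:
            # Same error the transport raises.
            raise core_exceptions.DuplicateCredentialArgs(
                "'credentials_file' and 'credentials' are mutually exclusive")
        if credentials_file is not None:
            # Load from an explicit credentials file.
            return google.auth.load_credentials_from_file(
                credentials_file, scopes=scopes, quota_project_id=quota_project_id)
        if credentials is None:
            # Application Default Credentials from the environment.
            return google.auth.default(scopes=scopes, quota_project_id=quota_project_id)
        return credentials, None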
+# +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import grpc_helpers +from google.api_core import operations_v1 +from google.api_core import gapic_v1 +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.ai.generativelanguage_v1beta3.types import model +from google.ai.generativelanguage_v1beta3.types import model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model +from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO + + +class ModelServiceGrpcTransport(ModelServiceTransport): + """gRPC backend transport for ModelService. + + Provides methods for getting metadata information about + Generative Models. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + _stubs: Dict[str, Callable] + + def __init__(self, *, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[grpc.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. 
It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client: Optional[operations_v1.OperationsClient] = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. 
This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel(cls, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Quick check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def get_model(self) -> Callable[ + [model_service.GetModelRequest], + model.Model]: + r"""Return a callable for the get model method over gRPC. + + Gets information about a specific Model. + + Returns: + Callable[[~.GetModelRequest], + ~.Model]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_model' not in self._stubs: + self._stubs['get_model'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/GetModel', + request_serializer=model_service.GetModelRequest.serialize, + response_deserializer=model.Model.deserialize, + ) + return self._stubs['get_model'] + + @property + def list_models(self) -> Callable[ + [model_service.ListModelsRequest], + model_service.ListModelsResponse]: + r"""Return a callable for the list models method over gRPC. + + Lists models available through the API. 
+ + Returns: + Callable[[~.ListModelsRequest], + ~.ListModelsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_models' not in self._stubs: + self._stubs['list_models'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/ListModels', + request_serializer=model_service.ListModelsRequest.serialize, + response_deserializer=model_service.ListModelsResponse.deserialize, + ) + return self._stubs['list_models'] + + @property + def get_tuned_model(self) -> Callable[ + [model_service.GetTunedModelRequest], + tuned_model.TunedModel]: + r"""Return a callable for the get tuned model method over gRPC. + + Gets information about a specific TunedModel. + + Returns: + Callable[[~.GetTunedModelRequest], + ~.TunedModel]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_tuned_model' not in self._stubs: + self._stubs['get_tuned_model'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/GetTunedModel', + request_serializer=model_service.GetTunedModelRequest.serialize, + response_deserializer=tuned_model.TunedModel.deserialize, + ) + return self._stubs['get_tuned_model'] + + @property + def list_tuned_models(self) -> Callable[ + [model_service.ListTunedModelsRequest], + model_service.ListTunedModelsResponse]: + r"""Return a callable for the list tuned models method over gRPC. + + Lists tuned models owned by the user. + + Returns: + Callable[[~.ListTunedModelsRequest], + ~.ListTunedModelsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_tuned_models' not in self._stubs: + self._stubs['list_tuned_models'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/ListTunedModels', + request_serializer=model_service.ListTunedModelsRequest.serialize, + response_deserializer=model_service.ListTunedModelsResponse.deserialize, + ) + return self._stubs['list_tuned_models'] + + @property + def create_tuned_model(self) -> Callable[ + [model_service.CreateTunedModelRequest], + operations_pb2.Operation]: + r"""Return a callable for the create tuned model method over gRPC. + + Creates a tuned model. Intermediate tuning progress (if any) is + accessed through the [google.longrunning.Operations] service. + + Status and results can be accessed through the Operations + service. Example: GET + /v1/tunedModels/az2mb0bpw6i/operations/000-111-222 + + Returns: + Callable[[~.CreateTunedModelRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if 'create_tuned_model' not in self._stubs: + self._stubs['create_tuned_model'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/CreateTunedModel', + request_serializer=model_service.CreateTunedModelRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs['create_tuned_model'] + + @property + def update_tuned_model(self) -> Callable[ + [model_service.UpdateTunedModelRequest], + gag_tuned_model.TunedModel]: + r"""Return a callable for the update tuned model method over gRPC. + + Updates a tuned model. + + Returns: + Callable[[~.UpdateTunedModelRequest], + ~.TunedModel]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_tuned_model' not in self._stubs: + self._stubs['update_tuned_model'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/UpdateTunedModel', + request_serializer=model_service.UpdateTunedModelRequest.serialize, + response_deserializer=gag_tuned_model.TunedModel.deserialize, + ) + return self._stubs['update_tuned_model'] + + @property + def delete_tuned_model(self) -> Callable[ + [model_service.DeleteTunedModelRequest], + empty_pb2.Empty]: + r"""Return a callable for the delete tuned model method over gRPC. + + Deletes a tuned model. + + Returns: + Callable[[~.DeleteTunedModelRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_tuned_model' not in self._stubs: + self._stubs['delete_tuned_model'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/DeleteTunedModel', + request_serializer=model_service.DeleteTunedModelRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs['delete_tuned_model'] + + def close(self): + self.grpc_channel.close() + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ( + 'ModelServiceGrpcTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc_asyncio.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..c8426f6c3910 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc_asyncio.py @@ -0,0 +1,448 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
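Before the AsyncIO variant: the `create_tuned_model` stub and the cached `operations_client` in `transports/grpc.py` above together give the full long-running-operation path. A hedged sketch of how that surfaces at the client level; the field values are placeholders, and it assumes the flattened `tuned_model` argument the generator emits for this method:

    # Illustrative LRO flow over the gRPC transport shown above.
    from google.ai.generativelanguage_v1beta3 import ModelServiceClient

    client = ModelServiceClient()  # gRPC is the default transport

    operation = client.create_tuned_model(
        tuned_model={
            "display_name": "example-tuned-model",  # placeholder
            "base_model": "models/text-bison-001",  # placeholder base model
        }
    )
    # result() polls the Operations service (GetOperation) until completion.
    tuned = operation.result()
    print(tuned.name)  # e.g. tunedModels/...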
+# +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers_async +from google.api_core import operations_v1 +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.ai.generativelanguage_v1beta3.types import model +from google.ai.generativelanguage_v1beta3.types import model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model +from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import ModelServiceGrpcTransport + + +class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport): + """gRPC AsyncIO backend transport for ModelService. + + Provides methods for getting metadata information about + Generative Models. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. 
+ """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs + ) + + def __init__(self, *, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[aio.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
+ """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client: Optional[operations_v1.OperationsAsyncClient] = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Quick check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def get_model(self) -> Callable[ + [model_service.GetModelRequest], + Awaitable[model.Model]]: + r"""Return a callable for the get model method over gRPC. + + Gets information about a specific Model. + + Returns: + Callable[[~.GetModelRequest], + Awaitable[~.Model]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if 'get_model' not in self._stubs: + self._stubs['get_model'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/GetModel', + request_serializer=model_service.GetModelRequest.serialize, + response_deserializer=model.Model.deserialize, + ) + return self._stubs['get_model'] + + @property + def list_models(self) -> Callable[ + [model_service.ListModelsRequest], + Awaitable[model_service.ListModelsResponse]]: + r"""Return a callable for the list models method over gRPC. + + Lists models available through the API. + + Returns: + Callable[[~.ListModelsRequest], + Awaitable[~.ListModelsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_models' not in self._stubs: + self._stubs['list_models'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/ListModels', + request_serializer=model_service.ListModelsRequest.serialize, + response_deserializer=model_service.ListModelsResponse.deserialize, + ) + return self._stubs['list_models'] + + @property + def get_tuned_model(self) -> Callable[ + [model_service.GetTunedModelRequest], + Awaitable[tuned_model.TunedModel]]: + r"""Return a callable for the get tuned model method over gRPC. + + Gets information about a specific TunedModel. + + Returns: + Callable[[~.GetTunedModelRequest], + Awaitable[~.TunedModel]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_tuned_model' not in self._stubs: + self._stubs['get_tuned_model'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/GetTunedModel', + request_serializer=model_service.GetTunedModelRequest.serialize, + response_deserializer=tuned_model.TunedModel.deserialize, + ) + return self._stubs['get_tuned_model'] + + @property + def list_tuned_models(self) -> Callable[ + [model_service.ListTunedModelsRequest], + Awaitable[model_service.ListTunedModelsResponse]]: + r"""Return a callable for the list tuned models method over gRPC. + + Lists tuned models owned by the user. + + Returns: + Callable[[~.ListTunedModelsRequest], + Awaitable[~.ListTunedModelsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_tuned_models' not in self._stubs: + self._stubs['list_tuned_models'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/ListTunedModels', + request_serializer=model_service.ListTunedModelsRequest.serialize, + response_deserializer=model_service.ListTunedModelsResponse.deserialize, + ) + return self._stubs['list_tuned_models'] + + @property + def create_tuned_model(self) -> Callable[ + [model_service.CreateTunedModelRequest], + Awaitable[operations_pb2.Operation]]: + r"""Return a callable for the create tuned model method over gRPC. + + Creates a tuned model. Intermediate tuning progress (if any) is + accessed through the [google.longrunning.Operations] service. 
+ + Status and results can be accessed through the Operations + service. Example: GET + /v1/tunedModels/az2mb0bpw6i/operations/000-111-222 + + Returns: + Callable[[~.CreateTunedModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_tuned_model' not in self._stubs: + self._stubs['create_tuned_model'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/CreateTunedModel', + request_serializer=model_service.CreateTunedModelRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs['create_tuned_model'] + + @property + def update_tuned_model(self) -> Callable[ + [model_service.UpdateTunedModelRequest], + Awaitable[gag_tuned_model.TunedModel]]: + r"""Return a callable for the update tuned model method over gRPC. + + Updates a tuned model. + + Returns: + Callable[[~.UpdateTunedModelRequest], + Awaitable[~.TunedModel]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_tuned_model' not in self._stubs: + self._stubs['update_tuned_model'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/UpdateTunedModel', + request_serializer=model_service.UpdateTunedModelRequest.serialize, + response_deserializer=gag_tuned_model.TunedModel.deserialize, + ) + return self._stubs['update_tuned_model'] + + @property + def delete_tuned_model(self) -> Callable[ + [model_service.DeleteTunedModelRequest], + Awaitable[empty_pb2.Empty]]: + r"""Return a callable for the delete tuned model method over gRPC. + + Deletes a tuned model. + + Returns: + Callable[[~.DeleteTunedModelRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if 'delete_tuned_model' not in self._stubs: + self._stubs['delete_tuned_model'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.ModelService/DeleteTunedModel', + request_serializer=model_service.DeleteTunedModelRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs['delete_tuned_model'] + + def close(self): + return self.grpc_channel.close() + + +__all__ = ( + 'ModelServiceGrpcAsyncIOTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/rest.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/rest.py new file mode 100644 index 000000000000..9d43aea96036 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/rest.py @@ -0,0 +1,972 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.auth.transport.requests import AuthorizedSession # type: ignore +import json # type: ignore +import grpc # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core import rest_helpers +from google.api_core import rest_streaming +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from google.api_core import operations_v1 +from requests import __version__ as requests_version +import dataclasses +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + + +from google.ai.generativelanguage_v1beta3.types import model +from google.ai.generativelanguage_v1beta3.types import model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model +from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model +from google.protobuf import empty_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore + +from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class ModelServiceRestInterceptor: + """Interceptor for ModelService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. 
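Before the REST specifics: the AsyncIO transport just completed is exercised through `ModelServiceAsyncClient`, whose methods are awaitable and whose pagers support `async for`. A minimal sketch, with illustrative model names:

    import asyncio
    from google.ai.generativelanguage_v1beta3 import ModelServiceAsyncClient

    async def main():
        client = ModelServiceAsyncClient()  # uses ModelServiceGrpcAsyncIOTransport
        model = await client.get_model(name="models/chat-bison-001")  # illustrative name
        print(model.display_name)

        pager = await client.list_tuned_models(request={"page_size": 5})
        async for tuned in pager:  # ListTunedModelsAsyncPager
            print(tuned.name)

    asyncio.run(main())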
+ Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the ModelServiceRestTransport. + + .. code-block:: python + class MyCustomModelServiceInterceptor(ModelServiceRestInterceptor): + def pre_create_tuned_model(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_create_tuned_model(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_delete_tuned_model(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def pre_get_model(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_model(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_get_tuned_model(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_tuned_model(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_models(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_models(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_tuned_models(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_tuned_models(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_update_tuned_model(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_tuned_model(self, response): + logging.log(f"Received response: {response}") + return response + + transport = ModelServiceRestTransport(interceptor=MyCustomModelServiceInterceptor()) + client = ModelServiceClient(transport=transport) + + + """ + def pre_create_tuned_model(self, request: model_service.CreateTunedModelRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.CreateTunedModelRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for create_tuned_model + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_create_tuned_model(self, response: operations_pb2.Operation) -> operations_pb2.Operation: + """Post-rpc interceptor for create_tuned_model + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + def pre_delete_tuned_model(self, request: model_service.DeleteTunedModelRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.DeleteTunedModelRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for delete_tuned_model + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def pre_get_model(self, request: model_service.GetModelRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.GetModelRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for get_model + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. 
+ """ + return request, metadata + + def post_get_model(self, response: model.Model) -> model.Model: + """Post-rpc interceptor for get_model + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + def pre_get_tuned_model(self, request: model_service.GetTunedModelRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.GetTunedModelRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for get_tuned_model + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_get_tuned_model(self, response: tuned_model.TunedModel) -> tuned_model.TunedModel: + """Post-rpc interceptor for get_tuned_model + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + def pre_list_models(self, request: model_service.ListModelsRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.ListModelsRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for list_models + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_list_models(self, response: model_service.ListModelsResponse) -> model_service.ListModelsResponse: + """Post-rpc interceptor for list_models + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + def pre_list_tuned_models(self, request: model_service.ListTunedModelsRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.ListTunedModelsRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for list_tuned_models + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_list_tuned_models(self, response: model_service.ListTunedModelsResponse) -> model_service.ListTunedModelsResponse: + """Post-rpc interceptor for list_tuned_models + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + def pre_update_tuned_model(self, request: model_service.UpdateTunedModelRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.UpdateTunedModelRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for update_tuned_model + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_update_tuned_model(self, response: gag_tuned_model.TunedModel) -> gag_tuned_model.TunedModel: + """Post-rpc interceptor for update_tuned_model + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class ModelServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: ModelServiceRestInterceptor + + +class ModelServiceRestTransport(ModelServiceTransport): + """REST backend transport for ModelService. + + Provides methods for getting metadata information about + Generative Models. 
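Concretely, the pre/post hooks above compose along the lines of the docstring example. The sketch below is illustrative (the header name and logging are stand-ins): it tags outgoing `GetModel` calls with extra metadata and logs the response, then wires the interceptor into a REST client:

    import logging

    from google.ai.generativelanguage_v1beta3 import ModelServiceClient
    from google.ai.generativelanguage_v1beta3.services.model_service.transports import (
        ModelServiceRestInterceptor,
        ModelServiceRestTransport,
    )

    class AuditingInterceptor(ModelServiceRestInterceptor):
        def pre_get_model(self, request, metadata):
            # Metadata tuples are sent as HTTP headers on the REST request.
            return request, list(metadata) + [("x-audit-note", "example")]

        def post_get_model(self, response):
            logging.info("GetModel returned %s", response.name)
            return response

    transport = ModelServiceRestTransport(interceptor=AuditingInterceptor())
    client = ModelServiceClient(transport=transport)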
+
+ This class defines the same methods as the primary client, so the
+ primary client can load the underlying transport implementation
+ and call it.
+
+ It sends JSON representations of protocol buffers over HTTP/1.1
+
+ """
+
+ def __init__(self, *,
+ host: str = 'generativelanguage.googleapis.com',
+ credentials: Optional[ga_credentials.Credentials] = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ client_cert_source_for_mtls: Optional[Callable[[
+ ], Tuple[bytes, bytes]]] = None,
+ quota_project_id: Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
+ url_scheme: str = 'https',
+ interceptor: Optional[ModelServiceRestInterceptor] = None,
+ api_audience: Optional[str] = None,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+ ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client
+ certificate to configure mutual TLS HTTP channel. It is ignored
+ if ``channel`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you are developing
+ your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self-signed JWT should
+ be used for service account credentials.
+ url_scheme: the protocol scheme for the API endpoint. Normally
+ "https", but for testing or local servers,
+ "http" can be specified.
+ """
+ # Run the base constructor
+ # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc.
+ # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
+ # credentials object
+ maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host)
+ if maybe_url_match is None:
+ raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER
+
+ url_match_items = maybe_url_match.groupdict()
+
+ host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host
+
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
+ api_audience=api_audience
+ )
+ self._session = AuthorizedSession(
+ self._credentials, default_host=self.DEFAULT_HOST)
+ self._operations_client: Optional[operations_v1.AbstractOperationsClient] = None
+ if client_cert_source_for_mtls:
+ self._session.configure_mtls_channel(client_cert_source_for_mtls)
+ self._interceptor = interceptor or ModelServiceRestInterceptor()
+ self._prep_wrapped_messages(client_info)
+
+ @property
+ def operations_client(self) -> operations_v1.AbstractOperationsClient:
+ """Create the client designed to process long-running operations.
+ + This property caches on the instance; repeated calls return the same + client. + """ + # Only create a new client if we do not already have one. + if self._operations_client is None: + http_options: Dict[str, List[Dict[str, str]]] = { + } + + rest_transport = operations_v1.OperationsRestTransport( + host=self._host, + # use the credentials which are saved + credentials=self._credentials, + scopes=self._scopes, + http_options=http_options, + path_prefix="v1beta3") + + self._operations_client = operations_v1.AbstractOperationsClient(transport=rest_transport) + + # Return the client from cache. + return self._operations_client + + class _CreateTunedModel(ModelServiceRestStub): + def __hash__(self): + return hash("CreateTunedModel") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: model_service.CreateTunedModelRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> operations_pb2.Operation: + r"""Call the create tuned model method over HTTP. + + Args: + request (~.model_service.CreateTunedModelRequest): + The request object. Request to create a TunedModel. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. + + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'post', + 'uri': '/v1beta3/tunedModels', + 'body': 'tuned_model', + }, + ] + request, metadata = self._interceptor.pre_create_tuned_model(request, metadata) + pb_request = model_service.CreateTunedModelRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + including_default_value_fields=False, + use_integers_for_enums=True + ) + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. 
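+ # from_http_response maps the status code to the matching
+ # google.api_core.exceptions subclass (for example, NotFound for a 404).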
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + resp = self._interceptor.post_create_tuned_model(resp) + return resp + + class _DeleteTunedModel(ModelServiceRestStub): + def __hash__(self): + return hash("DeleteTunedModel") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: model_service.DeleteTunedModelRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ): + r"""Call the delete tuned model method over HTTP. + + Args: + request (~.model_service.DeleteTunedModelRequest): + The request object. Request to delete a TunedModel. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'delete', + 'uri': '/v1beta3/{name=tunedModels/*}', + }, + ] + request, metadata = self._interceptor.pre_delete_tuned_model(request, metadata) + pb_request = model_service.DeleteTunedModelRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _GetModel(ModelServiceRestStub): + def __hash__(self): + return hash("GetModel") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: model_service.GetModelRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> model.Model: + r"""Call the get model method over HTTP. + + Args: + request (~.model_service.GetModelRequest): + The request object. Request for getting information about + a specific Model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.model.Model: + Information about a Generative + Language Model. 
+ + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'get', + 'uri': '/v1beta3/{name=models/*}', + }, + ] + request, metadata = self._interceptor.pre_get_model(request, metadata) + pb_request = model_service.GetModelRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = model.Model() + pb_resp = model.Model.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_get_model(resp) + return resp + + class _GetTunedModel(ModelServiceRestStub): + def __hash__(self): + return hash("GetTunedModel") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: model_service.GetTunedModelRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> tuned_model.TunedModel: + r"""Call the get tuned model method over HTTP. + + Args: + request (~.model_service.GetTunedModelRequest): + The request object. Request for getting information about + a specific Model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.tuned_model.TunedModel: + A fine-tuned model created using + ModelService.CreateTunedModel. 
+
+ """
+
+ http_options: List[Dict[str, str]] = [{
+ 'method': 'get',
+ 'uri': '/v1beta3/{name=tunedModels/*}',
+ },
+ ]
+ request, metadata = self._interceptor.pre_get_tuned_model(request, metadata)
+ pb_request = model_service.GetTunedModelRequest.pb(request)
+ transcoded_request = path_template.transcode(http_options, pb_request)
+
+ uri = transcoded_request['uri']
+ method = transcoded_request['method']
+
+ # Jsonify the query params
+ query_params = json.loads(json_format.MessageToJson(
+ transcoded_request['query_params'],
+ including_default_value_fields=False,
+ use_integers_for_enums=True,
+ ))
+ query_params.update(self._get_unset_required_fields(query_params))
+
+ query_params["$alt"] = "json;enum-encoding=int"
+
+ # Send the request
+ headers = dict(metadata)
+ headers['Content-Type'] = 'application/json'
+ response = getattr(self._session, method)(
+ "{host}{uri}".format(host=self._host, uri=uri),
+ timeout=timeout,
+ headers=headers,
+ params=rest_helpers.flatten_query_params(query_params, strict=True),
+ )
+
+ # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception
+ # subclass.
+ if response.status_code >= 400:
+ raise core_exceptions.from_http_response(response)
+
+ # Return the response
+ resp = tuned_model.TunedModel()
+ pb_resp = tuned_model.TunedModel.pb(resp)
+
+ json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)
+ resp = self._interceptor.post_get_tuned_model(resp)
+ return resp
+
+ class _ListModels(ModelServiceRestStub):
+ def __hash__(self):
+ return hash("ListModels")
+
+ def __call__(self,
+ request: model_service.ListModelsRequest, *,
+ retry: OptionalRetry=gapic_v1.method.DEFAULT,
+ timeout: Optional[float]=None,
+ metadata: Sequence[Tuple[str, str]]=(),
+ ) -> model_service.ListModelsResponse:
+ r"""Call the list models method over HTTP.
+
+ Args:
+ request (~.model_service.ListModelsRequest):
+ The request object. Request for listing all Models.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ ~.model_service.ListModelsResponse:
+ Response from ``ListModels`` containing a paginated list
+ of Models.
+
+ """
+
+ http_options: List[Dict[str, str]] = [{
+ 'method': 'get',
+ 'uri': '/v1beta3/models',
+ },
+ ]
+ request, metadata = self._interceptor.pre_list_models(request, metadata)
+ pb_request = model_service.ListModelsRequest.pb(request)
+ transcoded_request = path_template.transcode(http_options, pb_request)
+
+ uri = transcoded_request['uri']
+ method = transcoded_request['method']
+
+ # Jsonify the query params
+ query_params = json.loads(json_format.MessageToJson(
+ transcoded_request['query_params'],
+ including_default_value_fields=False,
+ use_integers_for_enums=True,
+ ))
+
+ query_params["$alt"] = "json;enum-encoding=int"
+
+ # Send the request
+ headers = dict(metadata)
+ headers['Content-Type'] = 'application/json'
+ response = getattr(self._session, method)(
+ "{host}{uri}".format(host=self._host, uri=uri),
+ timeout=timeout,
+ headers=headers,
+ params=rest_helpers.flatten_query_params(query_params, strict=True),
+ )
+
+ # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception
+ # subclass.
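+ # Note that list RPCs carry no request body; paging fields such as
+ # page_size and page_token travel in the query string assembled above.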
+ if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = model_service.ListModelsResponse() + pb_resp = model_service.ListModelsResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_list_models(resp) + return resp + + class _ListTunedModels(ModelServiceRestStub): + def __hash__(self): + return hash("ListTunedModels") + + def __call__(self, + request: model_service.ListTunedModelsRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> model_service.ListTunedModelsResponse: + r"""Call the list tuned models method over HTTP. + + Args: + request (~.model_service.ListTunedModelsRequest): + The request object. Request for listing TunedModels. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.model_service.ListTunedModelsResponse: + Response from ``ListTunedModels`` containing a paginated + list of Models. + + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'get', + 'uri': '/v1beta3/tunedModels', + }, + ] + request, metadata = self._interceptor.pre_list_tuned_models(request, metadata) + pb_request = model_service.ListTunedModelsRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = model_service.ListTunedModelsResponse() + pb_resp = model_service.ListTunedModelsResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_list_tuned_models(resp) + return resp + + class _UpdateTunedModel(ModelServiceRestStub): + def __hash__(self): + return hash("UpdateTunedModel") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + "updateMask" : {}, } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: model_service.UpdateTunedModelRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> gag_tuned_model.TunedModel: + r"""Call the update tuned model method over HTTP. + + Args: + request (~.model_service.UpdateTunedModelRequest): + The request object. Request to update a TunedModel. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. 
+ timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gag_tuned_model.TunedModel: + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'patch', + 'uri': '/v1beta3/{tuned_model.name=tunedModels/*}', + 'body': 'tuned_model', + }, + ] + request, metadata = self._interceptor.pre_update_tuned_model(request, metadata) + pb_request = model_service.UpdateTunedModelRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + including_default_value_fields=False, + use_integers_for_enums=True + ) + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = gag_tuned_model.TunedModel() + pb_resp = gag_tuned_model.TunedModel.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_update_tuned_model(resp) + return resp + + @property + def create_tuned_model(self) -> Callable[ + [model_service.CreateTunedModelRequest], + operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreateTunedModel(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_tuned_model(self) -> Callable[ + [model_service.DeleteTunedModelRequest], + empty_pb2.Empty]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeleteTunedModel(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_model(self) -> Callable[ + [model_service.GetModelRequest], + model.Model]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetModel(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_tuned_model(self) -> Callable[ + [model_service.GetTunedModelRequest], + tuned_model.TunedModel]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._GetTunedModel(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_models(self) -> Callable[ + [model_service.ListModelsRequest], + model_service.ListModelsResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListModels(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_tuned_models(self) -> Callable[ + [model_service.ListTunedModelsRequest], + model_service.ListTunedModelsResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListTunedModels(self._session, self._host, self._interceptor) # type: ignore + + @property + def update_tuned_model(self) -> Callable[ + [model_service.UpdateTunedModelRequest], + gag_tuned_model.TunedModel]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdateTunedModel(self._session, self._host, self._interceptor) # type: ignore + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__=( + 'ModelServiceRestTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/__init__.py new file mode 100644 index 000000000000..eb61a596594a --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .client import PermissionServiceClient +from .async_client import PermissionServiceAsyncClient + +__all__ = ( + 'PermissionServiceClient', + 'PermissionServiceAsyncClient', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py new file mode 100644 index 000000000000..f83af219c903 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py @@ -0,0 +1,876 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import functools +import re +from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version + +from google.api_core.client_options import ClientOptions +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.ai.generativelanguage_v1beta3.services.permission_service import pagers +from google.ai.generativelanguage_v1beta3.types import permission +from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.ai.generativelanguage_v1beta3.types import permission_service +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from .transports.base import PermissionServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import PermissionServiceGrpcAsyncIOTransport +from .client import PermissionServiceClient + + +class PermissionServiceAsyncClient: + """Provides methods for managing permissions to PaLM API + resources. + """ + + _client: PermissionServiceClient + + DEFAULT_ENDPOINT = PermissionServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = PermissionServiceClient.DEFAULT_MTLS_ENDPOINT + + permission_path = staticmethod(PermissionServiceClient.permission_path) + parse_permission_path = staticmethod(PermissionServiceClient.parse_permission_path) + tuned_model_path = staticmethod(PermissionServiceClient.tuned_model_path) + parse_tuned_model_path = staticmethod(PermissionServiceClient.parse_tuned_model_path) + common_billing_account_path = staticmethod(PermissionServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(PermissionServiceClient.parse_common_billing_account_path) + common_folder_path = staticmethod(PermissionServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(PermissionServiceClient.parse_common_folder_path) + common_organization_path = staticmethod(PermissionServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(PermissionServiceClient.parse_common_organization_path) + common_project_path = staticmethod(PermissionServiceClient.common_project_path) + parse_common_project_path = staticmethod(PermissionServiceClient.parse_common_project_path) + common_location_path = staticmethod(PermissionServiceClient.common_location_path) + parse_common_location_path = staticmethod(PermissionServiceClient.parse_common_location_path) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. 
+
+ Args:
+ info (dict): The service account private key info.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ PermissionServiceAsyncClient: The constructed client.
+ """
+ return PermissionServiceClient.from_service_account_info.__func__(PermissionServiceAsyncClient, info, *args, **kwargs) # type: ignore
+
+ @classmethod
+ def from_service_account_file(cls, filename: str, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ file.
+
+ Args:
+ filename (str): The path to the service account private key json
+ file.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ PermissionServiceAsyncClient: The constructed client.
+ """
+ return PermissionServiceClient.from_service_account_file.__func__(PermissionServiceAsyncClient, filename, *args, **kwargs) # type: ignore
+
+ from_service_account_json = from_service_account_file
+
+ @classmethod
+ def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None):
+ """Return the API endpoint and client cert source for mutual TLS.
+
+ The client cert source is determined in the following order:
+ (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
+ client cert source is None.
+ (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+ default client cert source exists, use the default one; otherwise the client cert
+ source is None.
+
+ The API endpoint is determined in the following order:
+ (1) if `client_options.api_endpoint` is provided, use the provided one.
+ (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+ default mTLS endpoint; if the environment variable is "never", use the default API
+ endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+ use the default API endpoint.
+
+ More details can be found at https://google.aip.dev/auth/4114.
+
+ Args:
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the
+ client. Only the `api_endpoint` and `client_cert_source` properties may be used
+ in this method.
+
+ Returns:
+ Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+ client cert source to use.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+ """
+ return PermissionServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore
+
+ @property
+ def transport(self) -> PermissionServiceTransport:
+ """Returns the transport used by the client instance.
+
+ Returns:
+ PermissionServiceTransport: The transport used by the client instance.
+ """
+ return self._client.transport
+
+ get_transport_class = functools.partial(type(PermissionServiceClient).get_transport_class, type(PermissionServiceClient))
+
+ def __init__(self, *,
+ credentials: Optional[ga_credentials.Credentials] = None,
+ transport: Union[str, PermissionServiceTransport] = "grpc_asyncio",
+ client_options: Optional[ClientOptions] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ ) -> None:
+ """Instantiates the permission service client.
+
+ Args:
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ transport (Union[str, ~.PermissionServiceTransport]): The
+ transport to use. If set to None, a transport is chosen
+ automatically.
+ client_options (ClientOptions): Custom options for the client. It
+ won't take effect if a ``transport`` instance is provided.
+ (1) The ``api_endpoint`` property can be used to override the
+ default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+ environment variable can also be used to override the endpoint:
+ "always" (always use the default mTLS endpoint), "never" (always
+ use the default regular endpoint) and "auto" (auto switch to the
+ default mTLS endpoint if client certificate is present, this is
+ the default value). However, the ``api_endpoint`` property takes
+ precedence if provided.
+ (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+ is "true", then the ``client_cert_source`` property can be used
+ to provide client certificate for mutual TLS transport. If
+ not provided, the default SSL client certificate will be used if
+ present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+ set, no client certificate will be used.
+
+ Raises:
+ google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
+ creation failed for any reason.
+ """
+ self._client = PermissionServiceClient(
+ credentials=credentials,
+ transport=transport,
+ client_options=client_options,
+ client_info=client_info,
+
+ )
+
+ async def create_permission(self,
+ request: Optional[Union[permission_service.CreatePermissionRequest, dict]] = None,
+ *,
+ parent: Optional[str] = None,
+ permission: Optional[gag_permission.Permission] = None,
+ retry: OptionalRetry = gapic_v1.method.DEFAULT,
+ timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> gag_permission.Permission:
+ r"""Create a permission to a specific resource.
+
+ .. code-block:: python
+
+ # This snippet has been automatically generated and should be regarded as a
+ # code template only.
+ # It will require modifications to work:
+ # - It may require correct/in-range values for request initialization.
+ # - It may require specifying regional endpoints when creating the service
+ # client as shown in:
+ # https://googleapis.dev/python/google-api-core/latest/client_options.html
+ from google.ai import generativelanguage_v1beta3
+
+ async def sample_create_permission():
+ # Create a client
+ client = generativelanguage_v1beta3.PermissionServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = generativelanguage_v1beta3.CreatePermissionRequest(
+ parent="parent_value",
+ )
+
+ # Make the request
+ response = await client.create_permission(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Optional[Union[google.ai.generativelanguage_v1beta3.types.CreatePermissionRequest, dict]]):
+ The request object. Request to create a ``Permission``.
+ parent (:class:`str`):
+ Required. The parent resource of the ``Permission``.
+ Format: tunedModels/{tuned_model}
+
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ permission (:class:`google.ai.generativelanguage_v1beta3.types.Permission`):
+ Required. The permission to create.
+ This corresponds to the ``permission`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, file). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, permission]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = permission_service.CreatePermissionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + if permission is not None: + request.permission = permission + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_permission, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("parent", request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_permission(self, + request: Optional[Union[permission_service.GetPermissionRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> permission.Permission: + r"""Gets information about a specific Permission. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service
+ # client as shown in:
+ # https://googleapis.dev/python/google-api-core/latest/client_options.html
+ from google.ai import generativelanguage_v1beta3
+
+ async def sample_get_permission():
+ # Create a client
+ client = generativelanguage_v1beta3.PermissionServiceAsyncClient()
+
+ # Initialize request argument(s)
+ request = generativelanguage_v1beta3.GetPermissionRequest(
+ name="name_value",
+ )
+
+ # Make the request
+ response = await client.get_permission(request=request)
+
+ # Handle the response
+ print(response)
+
+ Args:
+ request (Optional[Union[google.ai.generativelanguage_v1beta3.types.GetPermissionRequest, dict]]):
+ The request object. Request for getting information about a specific
+ ``Permission``.
+ name (:class:`str`):
+ Required. The resource name of the permission.
+
+ Format:
+ ``tunedModels/{tuned_model}/permissions/{permission}``
+
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.ai.generativelanguage_v1beta3.types.Permission:
+ Permission resource grants user,
+ group or the rest of the world access to
+ the PaLM API resource (e.g. a tuned
+ model, file).
+
+ A role is a collection of permitted
+ operations that allows users to perform
+ specific actions on PaLM API resources.
+ To make them available to users, groups,
+ or service accounts, you assign roles.
+ When you assign a role, you grant
+ permissions that the role contains.
+
+ There are three concentric roles. Each
+ role is a superset of the previous
+ role's permitted operations:
+
+ - reader can use the resource (e.g.
+ tuned model) for inference
+ - writer has reader's permissions and
+ additionally can edit and share
+ - owner has writer's permissions and
+ additionally can delete
+
+ """
+ # Create or coerce a protobuf request object.
+ # Quick check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError("If the `request` argument is set, then none of "
+ "the individual field arguments should be set.")
+
+ request = permission_service.GetPermissionRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.get_permission,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((
+ ("name", request.name),
+ )),
+ )
+
+ # Send the request.
+ response = await rpc(
+ request,
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+
+ # Done; return the response.
+ return response + + async def list_permissions(self, + request: Optional[Union[permission_service.ListPermissionsRequest, dict]] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListPermissionsAsyncPager: + r"""Lists permissions for the specific resource. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + async def sample_list_permissions(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.ListPermissionsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_permissions(request=request) + + # Handle the response + async for response in page_result: + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta3.types.ListPermissionsRequest, dict]]): + The request object. Request for listing permissions. + parent (:class:`str`): + Required. The parent resource of the permissions. + Format: tunedModels/{tuned_model} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.services.permission_service.pagers.ListPermissionsAsyncPager: + Response from ListPermissions containing a paginated list of + permissions. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = permission_service.ListPermissionsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_permissions, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("parent", request.parent), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. 
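+ # Iterating the pager fetches further pages lazily: whenever a page is
+ # exhausted, the wrapped RPC is re-invoked with the next_page_token from
+ # the previous ListPermissionsResponse.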
+ response = pagers.ListPermissionsAsyncPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_permission(self, + request: Optional[Union[permission_service.UpdatePermissionRequest, dict]] = None, + *, + permission: Optional[gag_permission.Permission] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_permission.Permission: + r"""Updates the permission. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + async def sample_update_permission(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.UpdatePermissionRequest( + ) + + # Make the request + response = await client.update_permission(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta3.types.UpdatePermissionRequest, dict]]): + The request object. Request to update the ``Permission``. + permission (:class:`google.ai.generativelanguage_v1beta3.types.Permission`): + Required. The permission to update. + + The permission's ``name`` field is used to identify the + permission to update. + + This corresponds to the ``permission`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The list of fields to update. Accepted ones: + + - role (``Permission.role`` field) + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, file). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
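+ # (Callers may pass either a fully-formed request message or the
+ # flattened ``permission``/``update_mask`` arguments, but never both.)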
+ has_flattened_params = any([permission, update_mask]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = permission_service.UpdatePermissionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if permission is not None: + request.permission = permission + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_permission, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("permission.name", request.permission.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_permission(self, + request: Optional[Union[permission_service.DeletePermissionRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes the permission. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + async def sample_delete_permission(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.DeletePermissionRequest( + name="name_value", + ) + + # Make the request + await client.delete_permission(request=request) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta3.types.DeletePermissionRequest, dict]]): + The request object. Request to delete the ``Permission``. + name (:class:`str`): + Required. The resource name of the permission. Format: + ``tunedModels/{tuned_model}/permissions/{permission}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = permission_service.DeletePermissionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
+ if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_permission, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def transfer_ownership(self, + request: Optional[Union[permission_service.TransferOwnershipRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> permission_service.TransferOwnershipResponse: + r"""Transfers ownership of the tuned model. + This is the only way to change ownership of the tuned + model. The current owner will be downgraded to writer + role. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + async def sample_transfer_ownership(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.TransferOwnershipRequest( + name="name_value", + email_address="email_address_value", + ) + + # Make the request + response = await client.transfer_ownership(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta3.types.TransferOwnershipRequest, dict]]): + The request object. Request to transfer the ownership of + the tuned model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.TransferOwnershipResponse: + Response from TransferOwnership. + """ + # Create or coerce a protobuf request object. + request = permission_service.TransferOwnershipRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.transfer_ownership, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + + async def __aenter__(self) -> "PermissionServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +__all__ = ( + "PermissionServiceAsyncClient", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/client.py new file mode 100644 index 000000000000..e25e91f72b5f --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/client.py @@ -0,0 +1,1094 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import os +import re +from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.ai.generativelanguage_v1beta3.services.permission_service import pagers +from google.ai.generativelanguage_v1beta3.types import permission +from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.ai.generativelanguage_v1beta3.types import permission_service +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from .transports.base import PermissionServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import PermissionServiceGrpcTransport +from .transports.grpc_asyncio import PermissionServiceGrpcAsyncIOTransport +from .transports.rest import PermissionServiceRestTransport + + +class PermissionServiceClientMeta(type): + """Metaclass for the PermissionService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. 
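+
+    The registry below maps the supported transport labels ("grpc",
+    "grpc_asyncio", "rest") to their transport classes; see
+    ``get_transport_class`` for how the default is chosen.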
+ """ + _transport_registry = OrderedDict() # type: Dict[str, Type[PermissionServiceTransport]] + _transport_registry["grpc"] = PermissionServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PermissionServiceGrpcAsyncIOTransport + _transport_registry["rest"] = PermissionServiceRestTransport + + def get_transport_class(cls, + label: Optional[str] = None, + ) -> Type[PermissionServiceTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class PermissionServiceClient(metaclass=PermissionServiceClientMeta): + """Provides methods for managing permissions to PaLM API + resources. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PermissionServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PermissionServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> PermissionServiceTransport: + """Returns the transport used by the client instance. + + Returns: + PermissionServiceTransport: The transport used by the client + instance. 
+ """ + return self._transport + + @staticmethod + def permission_path(tuned_model: str,permission: str,) -> str: + """Returns a fully-qualified permission string.""" + return "tunedModels/{tuned_model}/permissions/{permission}".format(tuned_model=tuned_model, permission=permission, ) + + @staticmethod + def parse_permission_path(path: str) -> Dict[str,str]: + """Parses a permission path into its component segments.""" + m = re.match(r"^tunedModels/(?P.+?)/permissions/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def tuned_model_path(tuned_model: str,) -> str: + """Returns a fully-qualified tuned_model string.""" + return "tunedModels/{tuned_model}".format(tuned_model=tuned_model, ) + + @staticmethod + def parse_tuned_model_path(path: str) -> Dict[str,str]: + """Parses a tuned_model path into its component segments.""" + m = re.match(r"^tunedModels/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. 
+        (2) if `client_options.client_cert_source` is provided, use the provided one; if the
+        default client cert source exists, use the default one; otherwise the client cert
+        source is None.
+
+        The API endpoint is determined in the following order:
+        (1) if `client_options.api_endpoint` is provided, use the provided one.
+        (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
+        default mTLS endpoint; if the environment variable is "never", use the default API
+        endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
+        use the default API endpoint.
+
+        More details can be found at https://google.aip.dev/auth/4114.
+
+        Args:
+            client_options (google.api_core.client_options.ClientOptions): Custom options for the
+                client. Only the `api_endpoint` and `client_cert_source` properties may be used
+                in this method.
+
+        Returns:
+            Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the
+                client cert source to use.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If any errors happen.
+        """
+        if client_options is None:
+            client_options = client_options_lib.ClientOptions()
+        use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
+        use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
+        if use_client_cert not in ("true", "false"):
+            raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`")
+        if use_mtls_endpoint not in ("auto", "never", "always"):
+            raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`")
+
+        # Figure out the client cert source to use.
+        client_cert_source = None
+        if use_client_cert == "true":
+            if client_options.client_cert_source:
+                client_cert_source = client_options.client_cert_source
+            elif mtls.has_default_client_cert_source():
+                client_cert_source = mtls.default_client_cert_source()
+
+        # Figure out which api endpoint to use.
+        if client_options.api_endpoint is not None:
+            api_endpoint = client_options.api_endpoint
+        elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source):
+            api_endpoint = cls.DEFAULT_MTLS_ENDPOINT
+        else:
+            api_endpoint = cls.DEFAULT_ENDPOINT
+
+        return api_endpoint, client_cert_source
+
+    def __init__(self, *,
+            credentials: Optional[ga_credentials.Credentials] = None,
+            transport: Optional[Union[str, PermissionServiceTransport]] = None,
+            client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None,
+            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+            ) -> None:
+        """Instantiates the permission service client.
+
+        Args:
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            transport (Union[str, PermissionServiceTransport]): The
+                transport to use. If set to None, a transport is chosen
+                automatically.
+            client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the
+                client. It won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client.
GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + client_options = cast(client_options_lib.ClientOptions, client_options) + + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) + + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError("client_options.api_key and credentials are mutually exclusive") + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, PermissionServiceTransport): + # transport is a PermissionServiceTransport instance. + if credentials or client_options.credentials_file or api_key_value: + raise ValueError("When providing a transport instance, " + "provide its credentials directly.") + if client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = transport + else: + import google.auth._default # type: ignore + + if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): + credentials = google.auth._default.get_api_key_credentials(api_key_value) + + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=client_options.api_audience, + ) + + def create_permission(self, + request: Optional[Union[permission_service.CreatePermissionRequest, dict]] = None, + *, + parent: Optional[str] = None, + permission: Optional[gag_permission.Permission] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_permission.Permission: + r"""Create a permission to a specific resource. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. 
+ # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_create_permission(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.CreatePermissionRequest( + parent="parent_value", + ) + + # Make the request + response = client.create_permission(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.CreatePermissionRequest, dict]): + The request object. Request to create a ``Permission``. + parent (str): + Required. The parent resource of the ``Permission``. + Format: tunedModels/{tuned_model} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + permission (google.ai.generativelanguage_v1beta3.types.Permission): + Required. The permission to create. + This corresponds to the ``permission`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, file). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, permission]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a permission_service.CreatePermissionRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, permission_service.CreatePermissionRequest): + request = permission_service.CreatePermissionRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + if permission is not None: + request.permission = permission + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
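+        # Note: the wrapped callables are precomputed by the transport's
+        # _prep_wrapped_messages() (see transports/base.py), so no
+        # per-call wrapping happens here.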
+        rpc = self._transport._wrapped_methods[self._transport.create_permission]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ("parent", request.parent),
+            )),
+        )
+
+        # Send the request.
+        response = rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    def get_permission(self,
+            request: Optional[Union[permission_service.GetPermissionRequest, dict]] = None,
+            *,
+            name: Optional[str] = None,
+            retry: OptionalRetry = gapic_v1.method.DEFAULT,
+            timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> permission.Permission:
+        r"""Gets information about a specific Permission.
+
+        .. code-block:: python
+
+            # This snippet has been automatically generated and should be regarded as a
+            # code template only.
+            # It will require modifications to work:
+            # - It may require correct/in-range values for request initialization.
+            # - It may require specifying regional endpoints when creating the service
+            #   client as shown in:
+            #   https://googleapis.dev/python/google-api-core/latest/client_options.html
+            from google.ai import generativelanguage_v1beta3
+
+            def sample_get_permission():
+                # Create a client
+                client = generativelanguage_v1beta3.PermissionServiceClient()
+
+                # Initialize request argument(s)
+                request = generativelanguage_v1beta3.GetPermissionRequest(
+                    name="name_value",
+                )
+
+                # Make the request
+                response = client.get_permission(request=request)
+
+                # Handle the response
+                print(response)
+
+        Args:
+            request (Union[google.ai.generativelanguage_v1beta3.types.GetPermissionRequest, dict]):
+                The request object. Request for getting information about a specific
+                ``Permission``.
+            name (str):
+                Required. The resource name of the permission.
+
+                Format:
+                ``tunedModels/{tuned_model}/permissions/{permission}``
+
+                This corresponds to the ``name`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.ai.generativelanguage_v1beta3.types.Permission:
+                Permission resource grants user,
+                group or the rest of the world access to
+                the PaLM API resource (e.g. a tuned
+                model, file).
+
+                A role is a collection of permitted
+                operations that allows users to perform
+                specific actions on PaLM API resources.
+                To make them available to users, groups,
+                or service accounts, you assign roles.
+                When you assign a role, you grant
+                permissions that the role contains.
+
+                There are three concentric roles. Each
+                role is a superset of the previous
+                role's permitted operations:
+
+                - reader can use the resource (e.g.
+                  tuned model) for inference
+                - writer has reader's permissions and
+                  additionally can edit and share
+                - owner has writer's permissions and
+                  additionally can delete
+
+        """
+        # Create or coerce a protobuf request object.
+        # Quick check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
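+        # Usage sketch (hypothetical resource name): the flattened form
+        #   client.get_permission(
+        #       name="tunedModels/my-model/permissions/my-permission")
+        # is equivalent to passing a GetPermissionRequest with that name set.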
+ has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a permission_service.GetPermissionRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, permission_service.GetPermissionRequest): + request = permission_service.GetPermissionRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_permission] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_permissions(self, + request: Optional[Union[permission_service.ListPermissionsRequest, dict]] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListPermissionsPager: + r"""Lists permissions for the specific resource. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_list_permissions(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.ListPermissionsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_permissions(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.ListPermissionsRequest, dict]): + The request object. Request for listing permissions. + parent (str): + Required. The parent resource of the permissions. + Format: tunedModels/{tuned_model} + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.services.permission_service.pagers.ListPermissionsPager: + Response from ListPermissions containing a paginated list of + permissions. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. 
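+        # A plain dict is accepted as well and is coerced into a
+        # ListPermissionsRequest below, e.g. (hypothetical parent):
+        #   client.list_permissions(request={"parent": "tunedModels/my-model"})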
+ # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a permission_service.ListPermissionsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, permission_service.ListPermissionsRequest): + request = permission_service.ListPermissionsRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_permissions] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("parent", request.parent), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPermissionsPager( + method=rpc, + request=request, + response=response, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_permission(self, + request: Optional[Union[permission_service.UpdatePermissionRequest, dict]] = None, + *, + permission: Optional[gag_permission.Permission] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_permission.Permission: + r"""Updates the permission. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_update_permission(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.UpdatePermissionRequest( + ) + + # Make the request + response = client.update_permission(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.UpdatePermissionRequest, dict]): + The request object. Request to update the ``Permission``. + permission (google.ai.generativelanguage_v1beta3.types.Permission): + Required. The permission to update. + + The permission's ``name`` field is used to identify the + permission to update. + + This corresponds to the ``permission`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The list of fields to update. 
Accepted ones: + + - role (``Permission.role`` field) + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, file). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([permission, update_mask]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a permission_service.UpdatePermissionRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, permission_service.UpdatePermissionRequest): + request = permission_service.UpdatePermissionRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if permission is not None: + request.permission = permission + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_permission] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("permission.name", request.permission.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_permission(self, + request: Optional[Union[permission_service.DeletePermissionRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Deletes the permission. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_delete_permission(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.DeletePermissionRequest( + name="name_value", + ) + + # Make the request + client.delete_permission(request=request) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.DeletePermissionRequest, dict]): + The request object. Request to delete the ``Permission``. + name (str): + Required. The resource name of the permission. Format: + ``tunedModels/{tuned_model}/permissions/{permission}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a permission_service.DeletePermissionRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, permission_service.DeletePermissionRequest): + request = permission_service.DeletePermissionRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_permission] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def transfer_ownership(self, + request: Optional[Union[permission_service.TransferOwnershipRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> permission_service.TransferOwnershipResponse: + r"""Transfers ownership of the tuned model. + This is the only way to change ownership of the tuned + model. The current owner will be downgraded to writer + role. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_transfer_ownership(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.TransferOwnershipRequest( + name="name_value", + email_address="email_address_value", + ) + + # Make the request + response = client.transfer_ownership(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.TransferOwnershipRequest, dict]): + The request object. Request to transfer the ownership of + the tuned model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.TransferOwnershipResponse: + Response from TransferOwnership. + """ + # Create or coerce a protobuf request object. + # Minor optimization to avoid making a copy if the user passes + # in a permission_service.TransferOwnershipRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, permission_service.TransferOwnershipRequest): + request = permission_service.TransferOwnershipRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.transfer_ownership] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "PermissionServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + + + + + + + + + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +__all__ = ( + "PermissionServiceClient", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/pagers.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/pagers.py new file mode 100644 index 000000000000..188d4c7963a6 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/pagers.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Any, AsyncIterator, Awaitable, Callable, Sequence, Tuple, Optional, Iterator + +from google.ai.generativelanguage_v1beta3.types import permission +from google.ai.generativelanguage_v1beta3.types import permission_service + + +class ListPermissionsPager: + """A pager for iterating through ``list_permissions`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1beta3.types.ListPermissionsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``permissions`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListPermissions`` requests and continue to iterate + through the ``permissions`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1beta3.types.ListPermissionsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., permission_service.ListPermissionsResponse], + request: permission_service.ListPermissionsRequest, + response: permission_service.ListPermissionsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1beta3.types.ListPermissionsRequest): + The initial request object. + response (google.ai.generativelanguage_v1beta3.types.ListPermissionsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = permission_service.ListPermissionsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[permission_service.ListPermissionsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterator[permission.Permission]: + for page in self.pages: + yield from page.permissions + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + + +class ListPermissionsAsyncPager: + """A pager for iterating through ``list_permissions`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1beta3.types.ListPermissionsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``permissions`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListPermissions`` requests and continue to iterate + through the ``permissions`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1beta3.types.ListPermissionsResponse` + attributes are available on the pager. 
If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + def __init__(self, + method: Callable[..., Awaitable[permission_service.ListPermissionsResponse]], + request: permission_service.ListPermissionsRequest, + response: permission_service.ListPermissionsResponse, + *, + metadata: Sequence[Tuple[str, str]] = ()): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1beta3.types.ListPermissionsRequest): + The initial request object. + response (google.ai.generativelanguage_v1beta3.types.ListPermissionsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = permission_service.ListPermissionsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[permission_service.ListPermissionsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + def __aiter__(self) -> AsyncIterator[permission.Permission]: + async def async_generator(): + async for page in self.pages: + for response in page.permissions: + yield response + + return async_generator() + + def __repr__(self) -> str: + return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/__init__.py new file mode 100644 index 000000000000..5232d4043c80 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/__init__.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import PermissionServiceTransport +from .grpc import PermissionServiceGrpcTransport +from .grpc_asyncio import PermissionServiceGrpcAsyncIOTransport +from .rest import PermissionServiceRestTransport +from .rest import PermissionServiceRestInterceptor + + +# Compile a registry of transports. 
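+# This registry mirrors the one built on PermissionServiceClientMeta: the
+# label passed as ``transport`` (e.g. "rest") selects a class, and the
+# first entry ("grpc") serves as the default when no label is given.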
+_transport_registry = OrderedDict() # type: Dict[str, Type[PermissionServiceTransport]] +_transport_registry['grpc'] = PermissionServiceGrpcTransport +_transport_registry['grpc_asyncio'] = PermissionServiceGrpcAsyncIOTransport +_transport_registry['rest'] = PermissionServiceRestTransport + +__all__ = ( + 'PermissionServiceTransport', + 'PermissionServiceGrpcTransport', + 'PermissionServiceGrpcAsyncIOTransport', + 'PermissionServiceRestTransport', + 'PermissionServiceRestInterceptor', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/base.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/base.py new file mode 100644 index 000000000000..d0d736d33e11 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/base.py @@ -0,0 +1,221 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version + +import google.auth # type: ignore +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1beta3.types import permission +from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.ai.generativelanguage_v1beta3.types import permission_service +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +class PermissionServiceTransport(abc.ABC): + """Abstract transport class for PermissionService.""" + + AUTH_SCOPES = ( + ) + + DEFAULT_HOST: str = 'generativelanguage.googleapis.com' + def __init__( + self, *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. 
These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, + **scopes_kwargs, + quota_project_id=quota_project_id + ) + elif credentials is None: + credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience(api_audience if api_audience else host) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ':' not in host: + host += ':443' + self._host = host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_permission: gapic_v1.method.wrap_method( + self.create_permission, + default_timeout=None, + client_info=client_info, + ), + self.get_permission: gapic_v1.method.wrap_method( + self.get_permission, + default_timeout=None, + client_info=client_info, + ), + self.list_permissions: gapic_v1.method.wrap_method( + self.list_permissions, + default_timeout=None, + client_info=client_info, + ), + self.update_permission: gapic_v1.method.wrap_method( + self.update_permission, + default_timeout=None, + client_info=client_info, + ), + self.delete_permission: gapic_v1.method.wrap_method( + self.delete_permission, + default_timeout=None, + client_info=client_info, + ), + self.transfer_ownership: gapic_v1.method.wrap_method( + self.transfer_ownership, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! 
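+
+        Concrete transports (gRPC, gRPC AsyncIO, REST) override this; the
+        abstract base implementation only raises NotImplementedError.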
+ """ + raise NotImplementedError() + + @property + def create_permission(self) -> Callable[ + [permission_service.CreatePermissionRequest], + Union[ + gag_permission.Permission, + Awaitable[gag_permission.Permission] + ]]: + raise NotImplementedError() + + @property + def get_permission(self) -> Callable[ + [permission_service.GetPermissionRequest], + Union[ + permission.Permission, + Awaitable[permission.Permission] + ]]: + raise NotImplementedError() + + @property + def list_permissions(self) -> Callable[ + [permission_service.ListPermissionsRequest], + Union[ + permission_service.ListPermissionsResponse, + Awaitable[permission_service.ListPermissionsResponse] + ]]: + raise NotImplementedError() + + @property + def update_permission(self) -> Callable[ + [permission_service.UpdatePermissionRequest], + Union[ + gag_permission.Permission, + Awaitable[gag_permission.Permission] + ]]: + raise NotImplementedError() + + @property + def delete_permission(self) -> Callable[ + [permission_service.DeletePermissionRequest], + Union[ + empty_pb2.Empty, + Awaitable[empty_pb2.Empty] + ]]: + raise NotImplementedError() + + @property + def transfer_ownership(self) -> Callable[ + [permission_service.TransferOwnershipRequest], + Union[ + permission_service.TransferOwnershipResponse, + Awaitable[permission_service.TransferOwnershipResponse] + ]]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ( + 'PermissionServiceTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc.py new file mode 100644 index 000000000000..b7d0a5160c35 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc.py @@ -0,0 +1,402 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import grpc_helpers +from google.api_core import gapic_v1 +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.ai.generativelanguage_v1beta3.types import permission +from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.ai.generativelanguage_v1beta3.types import permission_service +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from .base import PermissionServiceTransport, DEFAULT_CLIENT_INFO + + +class PermissionServiceGrpcTransport(PermissionServiceTransport): + """gRPC backend transport for PermissionService. 
+
+ Provides methods for managing permissions to PaLM API
+ resources.
+
+ This class defines the same methods as the primary client, so the
+ primary client can load the underlying transport implementation
+ and call it.
+
+ It sends protocol buffers over the wire using gRPC (which is built on
+ top of HTTP/2); the ``grpcio`` package must be installed.
+ """
+ _stubs: Dict[str, Callable]
+
+ def __init__(self, *,
+ host: str = 'generativelanguage.googleapis.com',
+ credentials: Optional[ga_credentials.Credentials] = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ channel: Optional[grpc.Channel] = None,
+ api_mtls_endpoint: Optional[str] = None,
+ client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+ ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None,
+ client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+ quota_project_id: Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
+ api_audience: Optional[str] = None,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ This argument is ignored if ``channel`` is provided.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+ ignored if ``channel`` is provided.
+ channel (Optional[grpc.Channel]): A ``Channel`` instance through
+ which to make calls.
+ api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+ If provided, it overrides the ``host`` argument and tries to create
+ a mutual TLS channel with client SSL credentials from
+ ``client_cert_source`` or application default SSL credentials.
+ client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ Deprecated. A callback to provide client SSL certificate bytes and
+ private key bytes, both in PEM format. It is ignored if
+ ``api_mtls_endpoint`` is None.
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+ for the grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure a mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+ creation failed for any reason.
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+ and ``credentials_file`` are passed.
+ """
+ self._grpc_channel = None
+ self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+ if channel:
+ # Ignore credentials if a channel was passed.
+ credentials = False
+ # If a channel was explicitly provided, set it.
+ self._grpc_channel = channel
+ self._ssl_channel_credentials = None
+
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
+
+ else:
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
+ api_audience=api_audience,
+ )
+
+ if not self._grpc_channel:
+ self._grpc_channel = type(self).create_channel(
+ self._host,
+ # use the credentials which are saved
+ credentials=self._credentials,
+ # Set ``credentials_file`` to ``None`` here as
+ # the credentials that we saved earlier should be used.
+ credentials_file=None,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
+ quota_project_id=quota_project_id,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
+
+ @classmethod
+ def create_channel(cls,
+ host: str = 'generativelanguage.googleapis.com',
+ credentials: Optional[ga_credentials.Credentials] = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ quota_project_id: Optional[str] = None,
+ **kwargs) -> grpc.Channel:
+ """Create and return a gRPC channel object.
+ Args:
+ host (Optional[str]): The host for the channel to use.
+ credentials (Optional[~.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify this application to the service. If
+ none are specified, the client will attempt to ascertain
+ the credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is mutually exclusive with credentials.
+ scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ kwargs (Optional[dict]): Keyword arguments, which are passed to the
+ channel creation.
+ Returns:
+ grpc.Channel: A gRPC channel object.
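+
+ A hedged example of direct channel creation (editorial sketch; the
+ transport normally calls this for you in ``__init__``, and with no
+ explicit credentials the channel falls back to application default
+ credentials):
+
+ .. code-block:: python
+
+ channel = PermissionServiceGrpcTransport.create_channel(
+ 'generativelanguage.googleapis.com',
+ scopes=None,  # defaults to cls.AUTH_SCOPES
+ )
+ transport = PermissionServiceGrpcTransport(channel=channel)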
+ + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def create_permission(self) -> Callable[ + [permission_service.CreatePermissionRequest], + gag_permission.Permission]: + r"""Return a callable for the create permission method over gRPC. + + Create a permission to a specific resource. + + Returns: + Callable[[~.CreatePermissionRequest], + ~.Permission]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_permission' not in self._stubs: + self._stubs['create_permission'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.PermissionService/CreatePermission', + request_serializer=permission_service.CreatePermissionRequest.serialize, + response_deserializer=gag_permission.Permission.deserialize, + ) + return self._stubs['create_permission'] + + @property + def get_permission(self) -> Callable[ + [permission_service.GetPermissionRequest], + permission.Permission]: + r"""Return a callable for the get permission method over gRPC. + + Gets information about a specific Permission. + + Returns: + Callable[[~.GetPermissionRequest], + ~.Permission]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_permission' not in self._stubs: + self._stubs['get_permission'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.PermissionService/GetPermission', + request_serializer=permission_service.GetPermissionRequest.serialize, + response_deserializer=permission.Permission.deserialize, + ) + return self._stubs['get_permission'] + + @property + def list_permissions(self) -> Callable[ + [permission_service.ListPermissionsRequest], + permission_service.ListPermissionsResponse]: + r"""Return a callable for the list permissions method over gRPC. + + Lists permissions for the specific resource. + + Returns: + Callable[[~.ListPermissionsRequest], + ~.ListPermissionsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
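+ # (Editorial note: the created stub is cached in self._stubs, so each
+ # channel method is built at most once per transport instance and
+ # reused on later property accesses.)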
+ if 'list_permissions' not in self._stubs: + self._stubs['list_permissions'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.PermissionService/ListPermissions', + request_serializer=permission_service.ListPermissionsRequest.serialize, + response_deserializer=permission_service.ListPermissionsResponse.deserialize, + ) + return self._stubs['list_permissions'] + + @property + def update_permission(self) -> Callable[ + [permission_service.UpdatePermissionRequest], + gag_permission.Permission]: + r"""Return a callable for the update permission method over gRPC. + + Updates the permission. + + Returns: + Callable[[~.UpdatePermissionRequest], + ~.Permission]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_permission' not in self._stubs: + self._stubs['update_permission'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.PermissionService/UpdatePermission', + request_serializer=permission_service.UpdatePermissionRequest.serialize, + response_deserializer=gag_permission.Permission.deserialize, + ) + return self._stubs['update_permission'] + + @property + def delete_permission(self) -> Callable[ + [permission_service.DeletePermissionRequest], + empty_pb2.Empty]: + r"""Return a callable for the delete permission method over gRPC. + + Deletes the permission. + + Returns: + Callable[[~.DeletePermissionRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_permission' not in self._stubs: + self._stubs['delete_permission'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.PermissionService/DeletePermission', + request_serializer=permission_service.DeletePermissionRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs['delete_permission'] + + @property + def transfer_ownership(self) -> Callable[ + [permission_service.TransferOwnershipRequest], + permission_service.TransferOwnershipResponse]: + r"""Return a callable for the transfer ownership method over gRPC. + + Transfers ownership of the tuned model. + This is the only way to change ownership of the tuned + model. The current owner will be downgraded to writer + role. + + Returns: + Callable[[~.TransferOwnershipRequest], + ~.TransferOwnershipResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
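+ # (Editorial note: unary_unary() binds the fully qualified RPC name to
+ # the request serializer and response deserializer; calling the
+ # returned stub performs a single request/response round trip.)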
+ if 'transfer_ownership' not in self._stubs: + self._stubs['transfer_ownership'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.PermissionService/TransferOwnership', + request_serializer=permission_service.TransferOwnershipRequest.serialize, + response_deserializer=permission_service.TransferOwnershipResponse.deserialize, + ) + return self._stubs['transfer_ownership'] + + def close(self): + self.grpc_channel.close() + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ( + 'PermissionServiceGrpcTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc_asyncio.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..3f07d3cf3bd1 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc_asyncio.py @@ -0,0 +1,401 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers_async +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.ai.generativelanguage_v1beta3.types import permission +from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.ai.generativelanguage_v1beta3.types import permission_service +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from .base import PermissionServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import PermissionServiceGrpcTransport + + +class PermissionServiceGrpcAsyncIOTransport(PermissionServiceTransport): + """gRPC AsyncIO backend transport for PermissionService. + + Provides methods for managing permissions to PaLM API + resources. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. 
+ Args:
+ host (Optional[str]): The host for the channel to use.
+ credentials (Optional[~.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify this application to the service. If
+ none are specified, the client will attempt to ascertain
+ the credentials from the environment.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is mutually exclusive with credentials.
+ scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ kwargs (Optional[dict]): Keyword arguments, which are passed to the
+ channel creation.
+ Returns:
+ aio.Channel: A gRPC AsyncIO channel object.
+ """
+
+ return grpc_helpers_async.create_channel(
+ host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ quota_project_id=quota_project_id,
+ default_scopes=cls.AUTH_SCOPES,
+ scopes=scopes,
+ default_host=cls.DEFAULT_HOST,
+ **kwargs
+ )
+
+ def __init__(self, *,
+ host: str = 'generativelanguage.googleapis.com',
+ credentials: Optional[ga_credentials.Credentials] = None,
+ credentials_file: Optional[str] = None,
+ scopes: Optional[Sequence[str]] = None,
+ channel: Optional[aio.Channel] = None,
+ api_mtls_endpoint: Optional[str] = None,
+ client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+ ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None,
+ client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None,
+ quota_project_id: Optional[str] = None,
+ client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+ always_use_jwt_access: Optional[bool] = False,
+ api_audience: Optional[str] = None,
+ ) -> None:
+ """Instantiate the transport.
+
+ Args:
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+ This argument is ignored if ``channel`` is provided.
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+ service. These are only used when credentials are not specified and
+ are passed to :func:`google.auth.default`.
+ channel (Optional[aio.Channel]): A ``Channel`` instance through
+ which to make calls.
+ api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+ If provided, it overrides the ``host`` argument and tries to create
+ a mutual TLS channel with client SSL credentials from
+ ``client_cert_source`` or application default SSL credentials.
+ client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ Deprecated. A callback to provide client SSL certificate bytes and
+ private key bytes, both in PEM format. It is ignored if
+ ``api_mtls_endpoint`` is None.
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+ for the grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+ A callback to provide client certificate bytes and private key bytes,
+ both in PEM format. It is used to configure a mutual TLS channel. It is
+ ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing
+ your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
+
+ Raises:
+ google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+ creation failed for any reason.
+ google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+ and ``credentials_file`` are passed.
+ """
+ self._grpc_channel = None
+ self._ssl_channel_credentials = ssl_channel_credentials
+ self._stubs: Dict[str, Callable] = {}
+
+ if api_mtls_endpoint:
+ warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+ if client_cert_source:
+ warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+ if channel:
+ # Ignore credentials if a channel was passed.
+ credentials = False
+ # If a channel was explicitly provided, set it.
+ self._grpc_channel = channel
+ self._ssl_channel_credentials = None
+ else:
+ if api_mtls_endpoint:
+ host = api_mtls_endpoint
+
+ # Create SSL credentials with client_cert_source or application
+ # default SSL credentials.
+ if client_cert_source:
+ cert, key = client_cert_source()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+ else:
+ self._ssl_channel_credentials = SslCredentials().ssl_credentials
+
+ else:
+ if client_cert_source_for_mtls and not ssl_channel_credentials:
+ cert, key = client_cert_source_for_mtls()
+ self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+ certificate_chain=cert, private_key=key
+ )
+
+ # The base transport sets the host, credentials and scopes
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ credentials_file=credentials_file,
+ scopes=scopes,
+ quota_project_id=quota_project_id,
+ client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
+ api_audience=api_audience,
+ )
+
+ if not self._grpc_channel:
+ self._grpc_channel = type(self).create_channel(
+ self._host,
+ # use the credentials which are saved
+ credentials=self._credentials,
+ # Set ``credentials_file`` to ``None`` here as
+ # the credentials that we saved earlier should be used.
+ credentials_file=None,
+ scopes=self._scopes,
+ ssl_credentials=self._ssl_channel_credentials,
+ quota_project_id=quota_project_id,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+ # Wrap messages. This must be done after self._grpc_channel exists
+ self._prep_wrapped_messages(client_info)
+
+ @property
+ def grpc_channel(self) -> aio.Channel:
+ """Return the channel designed to connect to this service.
+
+ This property caches on the instance; repeated calls return
+ the same channel.
+ """
+ # Return the channel from cache.
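+ # (Editorial note: the channel was created eagerly in __init__, either
+ # from the ``channel`` argument or via create_channel(), so this
+ # property never constructs a new one.)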
+ return self._grpc_channel + + @property + def create_permission(self) -> Callable[ + [permission_service.CreatePermissionRequest], + Awaitable[gag_permission.Permission]]: + r"""Return a callable for the create permission method over gRPC. + + Create a permission to a specific resource. + + Returns: + Callable[[~.CreatePermissionRequest], + Awaitable[~.Permission]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'create_permission' not in self._stubs: + self._stubs['create_permission'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.PermissionService/CreatePermission', + request_serializer=permission_service.CreatePermissionRequest.serialize, + response_deserializer=gag_permission.Permission.deserialize, + ) + return self._stubs['create_permission'] + + @property + def get_permission(self) -> Callable[ + [permission_service.GetPermissionRequest], + Awaitable[permission.Permission]]: + r"""Return a callable for the get permission method over gRPC. + + Gets information about a specific Permission. + + Returns: + Callable[[~.GetPermissionRequest], + Awaitable[~.Permission]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'get_permission' not in self._stubs: + self._stubs['get_permission'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.PermissionService/GetPermission', + request_serializer=permission_service.GetPermissionRequest.serialize, + response_deserializer=permission.Permission.deserialize, + ) + return self._stubs['get_permission'] + + @property + def list_permissions(self) -> Callable[ + [permission_service.ListPermissionsRequest], + Awaitable[permission_service.ListPermissionsResponse]]: + r"""Return a callable for the list permissions method over gRPC. + + Lists permissions for the specific resource. + + Returns: + Callable[[~.ListPermissionsRequest], + Awaitable[~.ListPermissionsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'list_permissions' not in self._stubs: + self._stubs['list_permissions'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.PermissionService/ListPermissions', + request_serializer=permission_service.ListPermissionsRequest.serialize, + response_deserializer=permission_service.ListPermissionsResponse.deserialize, + ) + return self._stubs['list_permissions'] + + @property + def update_permission(self) -> Callable[ + [permission_service.UpdatePermissionRequest], + Awaitable[gag_permission.Permission]]: + r"""Return a callable for the update permission method over gRPC. + + Updates the permission. + + Returns: + Callable[[~.UpdatePermissionRequest], + Awaitable[~.Permission]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'update_permission' not in self._stubs: + self._stubs['update_permission'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.PermissionService/UpdatePermission', + request_serializer=permission_service.UpdatePermissionRequest.serialize, + response_deserializer=gag_permission.Permission.deserialize, + ) + return self._stubs['update_permission'] + + @property + def delete_permission(self) -> Callable[ + [permission_service.DeletePermissionRequest], + Awaitable[empty_pb2.Empty]]: + r"""Return a callable for the delete permission method over gRPC. + + Deletes the permission. + + Returns: + Callable[[~.DeletePermissionRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'delete_permission' not in self._stubs: + self._stubs['delete_permission'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.PermissionService/DeletePermission', + request_serializer=permission_service.DeletePermissionRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs['delete_permission'] + + @property + def transfer_ownership(self) -> Callable[ + [permission_service.TransferOwnershipRequest], + Awaitable[permission_service.TransferOwnershipResponse]]: + r"""Return a callable for the transfer ownership method over gRPC. + + Transfers ownership of the tuned model. + This is the only way to change ownership of the tuned + model. The current owner will be downgraded to writer + role. + + Returns: + Callable[[~.TransferOwnershipRequest], + Awaitable[~.TransferOwnershipResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'transfer_ownership' not in self._stubs: + self._stubs['transfer_ownership'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.PermissionService/TransferOwnership', + request_serializer=permission_service.TransferOwnershipRequest.serialize, + response_deserializer=permission_service.TransferOwnershipResponse.deserialize, + ) + return self._stubs['transfer_ownership'] + + def close(self): + return self.grpc_channel.close() + + +__all__ = ( + 'PermissionServiceGrpcAsyncIOTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/rest.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/rest.py new file mode 100644 index 000000000000..913244341d45 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/rest.py @@ -0,0 +1,919 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from google.auth.transport.requests import AuthorizedSession # type: ignore
+import json # type: ignore
+import grpc # type: ignore
+from google.auth.transport.grpc import SslCredentials # type: ignore
+from google.auth import credentials as ga_credentials # type: ignore
+from google.api_core import exceptions as core_exceptions
+from google.api_core import retry as retries
+from google.api_core import rest_helpers
+from google.api_core import rest_streaming
+from google.api_core import path_template
+from google.api_core import gapic_v1
+
+from google.protobuf import json_format
+from requests import __version__ as requests_version
+import dataclasses
+import re
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+import warnings
+
+try:
+ OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+except AttributeError: # pragma: NO COVER
+ OptionalRetry = Union[retries.Retry, object] # type: ignore
+
+
+from google.ai.generativelanguage_v1beta3.types import permission
+from google.ai.generativelanguage_v1beta3.types import permission as gag_permission
+from google.ai.generativelanguage_v1beta3.types import permission_service
+from google.protobuf import empty_pb2 # type: ignore
+from google.longrunning import operations_pb2 # type: ignore
+
+from .base import PermissionServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO
+
+
+DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+ gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version,
+ grpc_version=None,
+ rest_version=requests_version,
+)
+
+
+class PermissionServiceRestInterceptor:
+ """Interceptor for PermissionService.
+
+ Interceptors are used to manipulate requests, request metadata, and responses
+ in arbitrary ways.
+ Example use cases include:
+ * Logging
+ * Verifying requests according to service or custom semantics
+ * Stripping extraneous information from responses
+
+ These use cases and more can be enabled by injecting an
+ instance of a custom subclass when constructing the PermissionServiceRestTransport.
+
+ .. code-block:: python
+ class MyCustomPermissionServiceInterceptor(PermissionServiceRestInterceptor):
+ def pre_create_permission(self, request, metadata):
+ logging.log(f"Received request: {request}")
+ return request, metadata
+
+ def post_create_permission(self, response):
+ logging.log(f"Received response: {response}")
+ return response
+
+ def pre_delete_permission(self, request, metadata):
+ logging.log(f"Received request: {request}")
+ return request, metadata
+
+ def pre_get_permission(self, request, metadata):
+ logging.log(f"Received request: {request}")
+ return request, metadata
+
+ def post_get_permission(self, response):
+ logging.log(f"Received response: {response}")
+ return response
+
+ def pre_list_permissions(self, request, metadata):
+ logging.log(f"Received request: {request}")
+ return request, metadata
+
+ def post_list_permissions(self, response):
+ logging.log(f"Received response: {response}")
+ return response
+
+ def pre_transfer_ownership(self, request, metadata):
+ logging.log(f"Received request: {request}")
+ return request, metadata
+
+ def post_transfer_ownership(self, response):
+ logging.log(f"Received response: {response}")
+ return response
+
+ def pre_update_permission(self, request, metadata):
+ logging.log(f"Received request: {request}")
+ return request, metadata
+
+ def post_update_permission(self, response):
+ logging.log(f"Received response: {response}")
+ return response
+
+ transport = PermissionServiceRestTransport(interceptor=MyCustomPermissionServiceInterceptor())
+ client = PermissionServiceClient(transport=transport)
+
+
+ """
+ def pre_create_permission(self, request: permission_service.CreatePermissionRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[permission_service.CreatePermissionRequest, Sequence[Tuple[str, str]]]:
+ """Pre-rpc interceptor for create_permission
+
+ Override in a subclass to manipulate the request or metadata
+ before they are sent to the PermissionService server.
+ """
+ return request, metadata
+
+ def post_create_permission(self, response: gag_permission.Permission) -> gag_permission.Permission:
+ """Post-rpc interceptor for create_permission
+
+ Override in a subclass to manipulate the response
+ after it is returned by the PermissionService server but before
+ it is returned to user code.
+ """
+ return response
+ def pre_delete_permission(self, request: permission_service.DeletePermissionRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[permission_service.DeletePermissionRequest, Sequence[Tuple[str, str]]]:
+ """Pre-rpc interceptor for delete_permission
+
+ Override in a subclass to manipulate the request or metadata
+ before they are sent to the PermissionService server.
+ """
+ return request, metadata
+
+ def pre_get_permission(self, request: permission_service.GetPermissionRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[permission_service.GetPermissionRequest, Sequence[Tuple[str, str]]]:
+ """Pre-rpc interceptor for get_permission
+
+ Override in a subclass to manipulate the request or metadata
+ before they are sent to the PermissionService server.
+ """
+ return request, metadata
+
+ def post_get_permission(self, response: permission.Permission) -> permission.Permission:
+ """Post-rpc interceptor for get_permission
+
+ Override in a subclass to manipulate the response
+ after it is returned by the PermissionService server but before
+ it is returned to user code.
+ """ + return response + def pre_list_permissions(self, request: permission_service.ListPermissionsRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[permission_service.ListPermissionsRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for list_permissions + + Override in a subclass to manipulate the request or metadata + before they are sent to the PermissionService server. + """ + return request, metadata + + def post_list_permissions(self, response: permission_service.ListPermissionsResponse) -> permission_service.ListPermissionsResponse: + """Post-rpc interceptor for list_permissions + + Override in a subclass to manipulate the response + after it is returned by the PermissionService server but before + it is returned to user code. + """ + return response + def pre_transfer_ownership(self, request: permission_service.TransferOwnershipRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[permission_service.TransferOwnershipRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for transfer_ownership + + Override in a subclass to manipulate the request or metadata + before they are sent to the PermissionService server. + """ + return request, metadata + + def post_transfer_ownership(self, response: permission_service.TransferOwnershipResponse) -> permission_service.TransferOwnershipResponse: + """Post-rpc interceptor for transfer_ownership + + Override in a subclass to manipulate the response + after it is returned by the PermissionService server but before + it is returned to user code. + """ + return response + def pre_update_permission(self, request: permission_service.UpdatePermissionRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[permission_service.UpdatePermissionRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for update_permission + + Override in a subclass to manipulate the request or metadata + before they are sent to the PermissionService server. + """ + return request, metadata + + def post_update_permission(self, response: gag_permission.Permission) -> gag_permission.Permission: + """Post-rpc interceptor for update_permission + + Override in a subclass to manipulate the response + after it is returned by the PermissionService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class PermissionServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: PermissionServiceRestInterceptor + + +class PermissionServiceRestTransport(PermissionServiceTransport): + """REST backend transport for PermissionService. + + Provides methods for managing permissions to PaLM API + resources. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + + """ + + def __init__(self, *, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[ + ], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = 'https', + interceptor: Optional[PermissionServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. 
+
+ Args:
+ host (Optional[str]):
+ The hostname to connect to.
+ credentials (Optional[google.auth.credentials.Credentials]): The
+ authorization credentials to attach to requests. These
+ credentials identify the application to the service; if none
+ are specified, the client will attempt to ascertain the
+ credentials from the environment.
+
+ credentials_file (Optional[str]): A file with credentials that can
+ be loaded with :func:`google.auth.load_credentials_from_file`.
+ This argument is ignored if ``channel`` is provided.
+ scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+ ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client
+ certificate to configure mutual TLS HTTP channel. It is ignored
+ if ``channel`` is provided.
+ quota_project_id (Optional[str]): An optional project to use for billing
+ and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+ The client info used to send a user-agent string along with
+ API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you are developing
+ your own client library.
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should
+ be used for service account credentials.
+ url_scheme: the protocol scheme for the API endpoint. Normally
+ "https", but for testing or local servers,
+ "http" can be specified.
+ """
+ # Run the base constructor
+ # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc.
+ # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the
+ # credentials object
+ maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host)
+ if maybe_url_match is None:
+ raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER
+
+ url_match_items = maybe_url_match.groupdict()
+
+ host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host
+
+ super().__init__(
+ host=host,
+ credentials=credentials,
+ client_info=client_info,
+ always_use_jwt_access=always_use_jwt_access,
+ api_audience=api_audience
+ )
+ self._session = AuthorizedSession(
+ self._credentials, default_host=self.DEFAULT_HOST)
+ if client_cert_source_for_mtls:
+ self._session.configure_mtls_channel(client_cert_source_for_mtls)
+ self._interceptor = interceptor or PermissionServiceRestInterceptor()
+ self._prep_wrapped_messages(client_info)
+
+ class _CreatePermission(PermissionServiceRestStub):
+ def __hash__(self):
+ return hash("CreatePermission")
+
+ __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {
+ }
+
+ @classmethod
+ def _get_unset_required_fields(cls, message_dict):
+ return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict}
+
+ def __call__(self,
+ request: permission_service.CreatePermissionRequest, *,
+ retry: OptionalRetry=gapic_v1.method.DEFAULT,
+ timeout: Optional[float]=None,
+ metadata: Sequence[Tuple[str, str]]=(),
+ ) -> gag_permission.Permission:
+ r"""Call the create permission method over HTTP.
+
+ Args:
+ request (~.permission_service.CreatePermissionRequest):
+ The request object. Request to create a ``Permission``.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
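+
+ An editorial sketch of the HTTP mapping used below (the parent
+ ``tunedModels/my-model`` is hypothetical):
+ ``POST /v1beta3/tunedModels/my-model/permissions`` with the
+ request's ``permission`` field serialized as the JSON body.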
+ + Returns: + ~.gag_permission.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, file). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'post', + 'uri': '/v1beta3/{parent=tunedModels/*}/permissions', + 'body': 'permission', + }, + ] + request, metadata = self._interceptor.pre_create_permission(request, metadata) + pb_request = permission_service.CreatePermissionRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + including_default_value_fields=False, + use_integers_for_enums=True + ) + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = gag_permission.Permission() + pb_resp = gag_permission.Permission.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_create_permission(resp) + return resp + + class _DeletePermission(PermissionServiceRestStub): + def __hash__(self): + return hash("DeletePermission") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: permission_service.DeletePermissionRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ): + r"""Call the delete permission method over HTTP. + + Args: + request (~.permission_service.DeletePermissionRequest): + The request object. Request to delete the ``Permission``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'delete', + 'uri': '/v1beta3/{name=tunedModels/*/permissions/*}', + }, + ] + request, metadata = self._interceptor.pre_delete_permission(request, metadata) + pb_request = permission_service.DeletePermissionRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _GetPermission(PermissionServiceRestStub): + def __hash__(self): + return hash("GetPermission") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: permission_service.GetPermissionRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> permission.Permission: + r"""Call the get permission method over HTTP. + + Args: + request (~.permission_service.GetPermissionRequest): + The request object. Request for getting information about a specific + ``Permission``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.permission.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, file). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. 
+ tuned model) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'get', + 'uri': '/v1beta3/{name=tunedModels/*/permissions/*}', + }, + ] + request, metadata = self._interceptor.pre_get_permission(request, metadata) + pb_request = permission_service.GetPermissionRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = permission.Permission() + pb_resp = permission.Permission.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_get_permission(resp) + return resp + + class _ListPermissions(PermissionServiceRestStub): + def __hash__(self): + return hash("ListPermissions") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: permission_service.ListPermissionsRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> permission_service.ListPermissionsResponse: + r"""Call the list permissions method over HTTP. + + Args: + request (~.permission_service.ListPermissionsRequest): + The request object. Request for listing permissions. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.permission_service.ListPermissionsResponse: + Response from ``ListPermissions`` containing a paginated + list of permissions. 
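+
+ An editorial sketch of the transcoded call (the parent
+ ``tunedModels/my-model`` is hypothetical; pagination fields,
+ if set on the request, travel as query parameters):
+ ``GET /v1beta3/tunedModels/my-model/permissions``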
+ + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'get', + 'uri': '/v1beta3/{parent=tunedModels/*}/permissions', + }, + ] + request, metadata = self._interceptor.pre_list_permissions(request, metadata) + pb_request = permission_service.ListPermissionsRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = permission_service.ListPermissionsResponse() + pb_resp = permission_service.ListPermissionsResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_list_permissions(resp) + return resp + + class _TransferOwnership(PermissionServiceRestStub): + def __hash__(self): + return hash("TransferOwnership") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: permission_service.TransferOwnershipRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> permission_service.TransferOwnershipResponse: + r"""Call the transfer ownership method over HTTP. + + Args: + request (~.permission_service.TransferOwnershipRequest): + The request object. Request to transfer the ownership of + the tuned model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.permission_service.TransferOwnershipResponse: + Response from ``TransferOwnership``. 
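+
+ An editorial sketch of the transcoded call (``tunedModels/my-model``
+ is hypothetical; ``'body': '*'`` below means the whole request
+ message is sent as the JSON body):
+ ``POST /v1beta3/tunedModels/my-model:transferOwnership``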
+ """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'post', + 'uri': '/v1beta3/{name=tunedModels/*}:transferOwnership', + 'body': '*', + }, + ] + request, metadata = self._interceptor.pre_transfer_ownership(request, metadata) + pb_request = permission_service.TransferOwnershipRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + including_default_value_fields=False, + use_integers_for_enums=True + ) + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = permission_service.TransferOwnershipResponse() + pb_resp = permission_service.TransferOwnershipResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_transfer_ownership(resp) + return resp + + class _UpdatePermission(PermissionServiceRestStub): + def __hash__(self): + return hash("UpdatePermission") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + "updateMask" : {}, } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: permission_service.UpdatePermissionRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> gag_permission.Permission: + r"""Call the update permission method over HTTP. + + Args: + request (~.permission_service.UpdatePermissionRequest): + The request object. Request to update the ``Permission``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.gag_permission.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, file). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. 
+ tuned model) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'patch', + 'uri': '/v1beta3/{permission.name=tunedModels/*/permissions/*}', + 'body': 'permission', + }, + ] + request, metadata = self._interceptor.pre_update_permission(request, metadata) + pb_request = permission_service.UpdatePermissionRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + including_default_value_fields=False, + use_integers_for_enums=True + ) + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = gag_permission.Permission() + pb_resp = gag_permission.Permission.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_update_permission(resp) + return resp + + @property + def create_permission(self) -> Callable[ + [permission_service.CreatePermissionRequest], + gag_permission.Permission]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreatePermission(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_permission(self) -> Callable[ + [permission_service.DeletePermissionRequest], + empty_pb2.Empty]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeletePermission(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_permission(self) -> Callable[ + [permission_service.GetPermissionRequest], + permission.Permission]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetPermission(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_permissions(self) -> Callable[ + [permission_service.ListPermissionsRequest], + permission_service.ListPermissionsResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._ListPermissions(self._session, self._host, self._interceptor) # type: ignore + + @property + def transfer_ownership(self) -> Callable[ + [permission_service.TransferOwnershipRequest], + permission_service.TransferOwnershipResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._TransferOwnership(self._session, self._host, self._interceptor) # type: ignore + + @property + def update_permission(self) -> Callable[ + [permission_service.UpdatePermissionRequest], + gag_permission.Permission]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdatePermission(self._session, self._host, self._interceptor) # type: ignore + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__=( + 'PermissionServiceRestTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/__init__.py new file mode 100644 index 000000000000..f167a9c3175d --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .client import TextServiceClient +from .async_client import TextServiceAsyncClient + +__all__ = ( + 'TextServiceClient', + 'TextServiceAsyncClient', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/async_client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/async_client.py new file mode 100644 index 000000000000..f1be99ee3ab2 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/async_client.py @@ -0,0 +1,760 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from collections import OrderedDict +import functools +import re +from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version + +from google.api_core.client_options import ClientOptions +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + +from google.ai.generativelanguage_v1beta3.types import safety +from google.ai.generativelanguage_v1beta3.types import text_service +from google.longrunning import operations_pb2 # type: ignore +from .transports.base import TextServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import TextServiceGrpcAsyncIOTransport +from .client import TextServiceClient + + +class TextServiceAsyncClient: + """API for using Generative Language Models (GLMs) trained to + generate text. + Also known as Large Language Models (LLM)s, these generate text + given an input prompt from the user. + """ + + _client: TextServiceClient + + DEFAULT_ENDPOINT = TextServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = TextServiceClient.DEFAULT_MTLS_ENDPOINT + + model_path = staticmethod(TextServiceClient.model_path) + parse_model_path = staticmethod(TextServiceClient.parse_model_path) + common_billing_account_path = staticmethod(TextServiceClient.common_billing_account_path) + parse_common_billing_account_path = staticmethod(TextServiceClient.parse_common_billing_account_path) + common_folder_path = staticmethod(TextServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(TextServiceClient.parse_common_folder_path) + common_organization_path = staticmethod(TextServiceClient.common_organization_path) + parse_common_organization_path = staticmethod(TextServiceClient.parse_common_organization_path) + common_project_path = staticmethod(TextServiceClient.common_project_path) + parse_common_project_path = staticmethod(TextServiceClient.parse_common_project_path) + common_location_path = staticmethod(TextServiceClient.common_location_path) + parse_common_location_path = staticmethod(TextServiceClient.parse_common_location_path) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TextServiceAsyncClient: The constructed client. + """ + return TextServiceClient.from_service_account_info.__func__(TextServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TextServiceAsyncClient: The constructed client. 
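+
+        Example (the filename is illustrative)::
+
+            client = TextServiceAsyncClient.from_service_account_file(
+                "service-account.json")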
+ """ + return TextServiceClient.from_service_account_file.__func__(TextServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return TextServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> TextServiceTransport: + """Returns the transport used by the client instance. + + Returns: + TextServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial(type(TextServiceClient).get_transport_class, type(TextServiceClient)) + + def __init__(self, *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Union[str, TextServiceTransport] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the text service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.TextServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. 
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+        """
+        self._client = TextServiceClient(
+            credentials=credentials,
+            transport=transport,
+            client_options=client_options,
+            client_info=client_info,
+
+        )
+
+    async def generate_text(self,
+            request: Optional[Union[text_service.GenerateTextRequest, dict]] = None,
+            *,
+            model: Optional[str] = None,
+            prompt: Optional[text_service.TextPrompt] = None,
+            temperature: Optional[float] = None,
+            candidate_count: Optional[int] = None,
+            max_output_tokens: Optional[int] = None,
+            top_p: Optional[float] = None,
+            top_k: Optional[int] = None,
+            retry: OptionalRetry = gapic_v1.method.DEFAULT,
+            timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> text_service.GenerateTextResponse:
+        r"""Generates a response from the model given an input
+        message.
+
+        .. code-block:: python
+
+            # This snippet has been automatically generated and should be regarded as a
+            # code template only.
+            # It will require modifications to work:
+            # - It may require correct/in-range values for request initialization.
+            # - It may require specifying regional endpoints when creating the service
+            #   client as shown in:
+            #   https://googleapis.dev/python/google-api-core/latest/client_options.html
+            from google.ai import generativelanguage_v1beta3
+
+            async def sample_generate_text():
+                # Create a client
+                client = generativelanguage_v1beta3.TextServiceAsyncClient()
+
+                # Initialize request argument(s)
+                prompt = generativelanguage_v1beta3.TextPrompt()
+                prompt.text = "text_value"
+
+                request = generativelanguage_v1beta3.GenerateTextRequest(
+                    model="model_value",
+                    prompt=prompt,
+                )
+
+                # Make the request
+                response = await client.generate_text(request=request)
+
+                # Handle the response
+                print(response)
+
+        Args:
+            request (Optional[Union[google.ai.generativelanguage_v1beta3.types.GenerateTextRequest, dict]]):
+                The request object. Request to generate a text completion
+                response from the model.
+            model (:class:`str`):
+                Required. The name of the ``Model`` or ``TunedModel`` to
+                use for generating the completion. Examples:
+                models/text-bison-001
+                tunedModels/sentence-translator-u3b7m
+
+                This corresponds to the ``model`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            prompt (:class:`google.ai.generativelanguage_v1beta3.types.TextPrompt`):
+                Required. The free-form input text
+                given to the model as a prompt.
+                Given a prompt, the model will generate
+                a TextCompletion response it predicts as
+                the completion of the input text.
+
+                This corresponds to the ``prompt`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            temperature (:class:`float`):
+                Optional. Controls the randomness of the output. Note:
+                The default value varies by model, see the
+                ``Model.temperature`` attribute of the ``Model``
+                returned from the ``getModel`` function.
+
+                Values can range from [0.0,1.0], inclusive. A value
+                closer to 1.0 will produce responses that are more
+                varied and creative, while a value closer to 0.0 will
+                typically result in more straightforward responses from
+                the model.
+
+                This corresponds to the ``temperature`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            candidate_count (:class:`int`):
+                Optional. Number of generated responses to return.
+
+                This value must be between [1, 8], inclusive. If unset,
+                this will default to 1.
+
+                This corresponds to the ``candidate_count`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            max_output_tokens (:class:`int`):
+                Optional. The maximum number of tokens to include in a
+                candidate.
+
+                If unset, this will default to output_token_limit
+                specified in the ``Model`` specification.
+
+                This corresponds to the ``max_output_tokens`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            top_p (:class:`float`):
+                Optional. The maximum cumulative probability of tokens
+                to consider when sampling.
+
+                The model uses combined Top-k and nucleus sampling.
+
+                Tokens are sorted based on their assigned probabilities
+                so that only the most likely tokens are considered.
+                Top-k sampling directly limits the maximum number of
+                tokens to consider, while nucleus sampling limits the
+                number of tokens based on the cumulative probability.
+
+                Note: The default value varies by model, see the
+                ``Model.top_p`` attribute of the ``Model`` returned from
+                the ``getModel`` function.
+
+                This corresponds to the ``top_p`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            top_k (:class:`int`):
+                Optional. The maximum number of tokens to consider when
+                sampling.
+
+                The model uses combined Top-k and nucleus sampling.
+
+                Top-k sampling considers the set of ``top_k`` most
+                probable tokens. Defaults to 40.
+
+                Note: The default value varies by model, see the
+                ``Model.top_k`` attribute of the ``Model`` returned from
+                the ``getModel`` function.
+
+                This corresponds to the ``top_k`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.ai.generativelanguage_v1beta3.types.GenerateTextResponse:
+                The response from the model,
+                including candidate completions.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Quick check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([model, prompt, temperature, candidate_count, max_output_tokens, top_p, top_k])
+        if request is not None and has_flattened_params:
+            raise ValueError("If the `request` argument is set, then none of "
+                             "the individual field arguments should be set.")
+
+        request = text_service.GenerateTextRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
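+        # Each flattened keyword passed by the caller is copied onto the
+        # matching request field below; e.g. (illustrative values only)
+        #     await client.generate_text(model="models/text-bison-001",
+        #                                prompt=prompt, temperature=0.7)
+        # populates exactly those three fields and leaves the rest unset.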
+ if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + if temperature is not None: + request.temperature = temperature + if candidate_count is not None: + request.candidate_count = candidate_count + if max_output_tokens is not None: + request.max_output_tokens = max_output_tokens + if top_p is not None: + request.top_p = top_p + if top_k is not None: + request.top_k = top_k + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.generate_text, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def embed_text(self, + request: Optional[Union[text_service.EmbedTextRequest, dict]] = None, + *, + model: Optional[str] = None, + text: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.EmbedTextResponse: + r"""Generates an embedding from the model given an input + message. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + async def sample_embed_text(): + # Create a client + client = generativelanguage_v1beta3.TextServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.EmbedTextRequest( + model="model_value", + text="text_value", + ) + + # Make the request + response = await client.embed_text(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta3.types.EmbedTextRequest, dict]]): + The request object. Request to get a text embedding from + the model. + model (:class:`str`): + Required. The model name to use with + the format model=models/{model}. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + text (:class:`str`): + Required. The free-form input text + that the model will turn into an + embedding. + + This corresponds to the ``text`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.EmbedTextResponse: + The response to a EmbedTextRequest. + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
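+        # For example (illustrative values), the following two calls are
+        # equivalent, while mixing `request` with `model=` or `text=`
+        # raises ValueError:
+        #     await client.embed_text(model="models/embedding-gecko-001", text="hi")
+        #     await client.embed_text(request={"model": "models/embedding-gecko-001", "text": "hi"})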
+ has_flattened_params = any([model, text]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = text_service.EmbedTextRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if text is not None: + request.text = text + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.embed_text, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def batch_embed_text(self, + request: Optional[Union[text_service.BatchEmbedTextRequest, dict]] = None, + *, + model: Optional[str] = None, + texts: Optional[MutableSequence[str]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.BatchEmbedTextResponse: + r"""Generates multiple embeddings from the model given + input text in a synchronous call. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + async def sample_batch_embed_text(): + # Create a client + client = generativelanguage_v1beta3.TextServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.BatchEmbedTextRequest( + model="model_value", + texts=['texts_value1', 'texts_value2'], + ) + + # Make the request + response = await client.batch_embed_text(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta3.types.BatchEmbedTextRequest, dict]]): + The request object. Batch request to get a text embedding + from the model. + model (:class:`str`): + Required. The name of the ``Model`` to use for + generating the embedding. Examples: + models/embedding-gecko-001 + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + texts (:class:`MutableSequence[str]`): + Required. The free-form input texts + that the model will turn into an + embedding. The current limit is 100 + texts, over which an error will be + thrown. + + This corresponds to the ``texts`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + google.ai.generativelanguage_v1beta3.types.BatchEmbedTextResponse: + The response to a EmbedTextRequest. + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, texts]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = text_service.BatchEmbedTextRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if texts: + request.texts.extend(texts) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.batch_embed_text, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def count_text_tokens(self, + request: Optional[Union[text_service.CountTextTokensRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[text_service.TextPrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.CountTextTokensResponse: + r"""Runs a model's tokenizer on a text and returns the + token count. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + async def sample_count_text_tokens(): + # Create a client + client = generativelanguage_v1beta3.TextServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta3.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1beta3.CountTextTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.count_text_tokens(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1beta3.types.CountTextTokensRequest, dict]]): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + model (:class:`str`): + Required. The model's resource name. This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (:class:`google.ai.generativelanguage_v1beta3.types.TextPrompt`): + Required. 
The free-form input text + given to the model as a prompt. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.CountTextTokensResponse: + A response from CountTextTokens. + + It returns the model's token_count for the prompt. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt]) + if request is not None and has_flattened_params: + raise ValueError("If the `request` argument is set, then none of " + "the individual field arguments should be set.") + + request = text_service.CountTextTokensRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.count_text_tokens, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "TextServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +__all__ = ( + "TextServiceAsyncClient", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/client.py new file mode 100644 index 000000000000..ed05565059cc --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/client.py @@ -0,0 +1,968 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+from collections import OrderedDict
+import os
+import re
+from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast
+
+from google.ai.generativelanguage_v1beta3 import gapic_version as package_version
+
+from google.api_core import client_options as client_options_lib
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import retry as retries
+from google.auth import credentials as ga_credentials  # type: ignore
+from google.auth.transport import mtls  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+from google.auth.exceptions import MutualTLSChannelError  # type: ignore
+from google.oauth2 import service_account  # type: ignore
+
+try:
+    OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
+except AttributeError:  # pragma: NO COVER
+    OptionalRetry = Union[retries.Retry, object]  # type: ignore
+
+from google.ai.generativelanguage_v1beta3.types import safety
+from google.ai.generativelanguage_v1beta3.types import text_service
+from google.longrunning import operations_pb2  # type: ignore
+from .transports.base import TextServiceTransport, DEFAULT_CLIENT_INFO
+from .transports.grpc import TextServiceGrpcTransport
+from .transports.grpc_asyncio import TextServiceGrpcAsyncIOTransport
+from .transports.rest import TextServiceRestTransport
+
+
+class TextServiceClientMeta(type):
+    """Metaclass for the TextService client.
+
+    This provides class-level methods for building and retrieving
+    support objects (e.g. transport) without polluting the client instance
+    objects.
+    """
+    _transport_registry = OrderedDict()  # type: Dict[str, Type[TextServiceTransport]]
+    _transport_registry["grpc"] = TextServiceGrpcTransport
+    _transport_registry["grpc_asyncio"] = TextServiceGrpcAsyncIOTransport
+    _transport_registry["rest"] = TextServiceRestTransport
+
+    def get_transport_class(cls,
+            label: Optional[str] = None,
+            ) -> Type[TextServiceTransport]:
+        """Returns an appropriate transport class.
+
+        Args:
+            label: The name of the desired transport. If none is
+                provided, then the first transport in the registry is used.
+
+        Returns:
+            The transport class to use.
+        """
+        # If a specific transport is requested, return that one.
+        if label:
+            return cls._transport_registry[label]
+
+        # No transport is requested; return the default (that is, the first one
+        # in the dictionary).
+        return next(iter(cls._transport_registry.values()))
+
+
+class TextServiceClient(metaclass=TextServiceClientMeta):
+    """API for using Generative Language Models (GLMs) trained to
+    generate text.
+    Also known as Large Language Models (LLMs), these generate text
+    given an input prompt from the user.
+    """
+
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Converts api endpoint to mTLS endpoint.
+
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+ ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TextServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TextServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file( + filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> TextServiceTransport: + """Returns the transport used by the client instance. + + Returns: + TextServiceTransport: The transport used by the client + instance. 
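+
+        Note: by default this is the first transport registered on
+        ``TextServiceClientMeta`` (gRPC); passing ``transport="rest"`` or
+        ``transport="grpc_asyncio"`` when constructing the client selects
+        that transport instead.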
+ """ + return self._transport + + @staticmethod + def model_path(model: str,) -> str: + """Returns a fully-qualified model string.""" + return "models/{model}".format(model=model, ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str,str]: + """Parses a model path into its component segments.""" + m = re.match(r"^models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str, ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str,str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str, ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder, ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str,str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str, ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization, ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str,str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str, ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format(project=project, ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str,str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str, ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format(project=project, location=location, ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str,str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. 
+ + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`") + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`") + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + def __init__(self, *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[Union[str, TextServiceTransport]] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the text service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, TextServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. 
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + client_options = cast(client_options_lib.ClientOptions, client_options) + + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) + + api_key_value = getattr(client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError("client_options.api_key and credentials are mutually exclusive") + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, TextServiceTransport): + # transport is a TextServiceTransport instance. + if credentials or client_options.credentials_file or api_key_value: + raise ValueError("When providing a transport instance, " + "provide its credentials directly.") + if client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = transport + else: + import google.auth._default # type: ignore + + if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): + credentials = google.auth._default.get_api_key_credentials(api_key_value) + + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=client_options.api_audience, + ) + + def generate_text(self, + request: Optional[Union[text_service.GenerateTextRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[text_service.TextPrompt] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + max_output_tokens: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.GenerateTextResponse: + r"""Generates a response from the model given an input + message. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_generate_text(): + # Create a client + client = generativelanguage_v1beta3.TextServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta3.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1beta3.GenerateTextRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.generate_text(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.GenerateTextRequest, dict]): + The request object. Request to generate a text completion + response from the model. + model (str): + Required. The name of the ``Model`` or ``TunedModel`` to + use for generating the completion. Examples: + models/text-bison-001 + tunedModels/sentence-translator-u3b7m + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (google.ai.generativelanguage_v1beta3.types.TextPrompt): + Required. The free-form input text + given to the model as a prompt. + Given a prompt, the model will generate + a TextCompletion response it predicts as + the completion of the input text. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + temperature (float): + Optional. Controls the randomness of the output. Note: + The default value varies by model, see the + ``Model.temperature`` attribute of the ``Model`` + returned the ``getModel`` function. + + Values can range from [0.0,1.0], inclusive. A value + closer to 1.0 will produce responses that are more + varied and creative, while a value closer to 0.0 will + typically result in more straightforward responses from + the model. + + This corresponds to the ``temperature`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + candidate_count (int): + Optional. Number of generated responses to return. + + This value must be between [1, 8], inclusive. If unset, + this will default to 1. + + This corresponds to the ``candidate_count`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + max_output_tokens (int): + Optional. The maximum number of tokens to include in a + candidate. + + If unset, this will default to output_token_limit + specified in the ``Model`` specification. + + This corresponds to the ``max_output_tokens`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_p (float): + Optional. The maximum cumulative probability of tokens + to consider when sampling. + + The model uses combined Top-k and nucleus sampling. + + Tokens are sorted based on their assigned probabilities + so that only the most likely tokens are considered. + Top-k sampling directly limits the maximum number of + tokens to consider, while Nucleus sampling limits number + of tokens based on the cumulative probability. + + Note: The default value varies by model, see the + ``Model.top_p`` attribute of the ``Model`` returned the + ``getModel`` function. + + This corresponds to the ``top_p`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_k (int): + Optional. 
The maximum number of tokens to consider when + sampling. + + The model uses combined Top-k and nucleus sampling. + + Top-k sampling considers the set of ``top_k`` most + probable tokens. Defaults to 40. + + Note: The default value varies by model, see the + ``Model.top_k`` attribute of the ``Model`` returned the + ``getModel`` function. + + This corresponds to the ``top_k`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.GenerateTextResponse: + The response from the model, + including candidate completions. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt, temperature, candidate_count, max_output_tokens, top_p, top_k]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a text_service.GenerateTextRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, text_service.GenerateTextRequest): + request = text_service.GenerateTextRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + if temperature is not None: + request.temperature = temperature + if candidate_count is not None: + request.candidate_count = candidate_count + if max_output_tokens is not None: + request.max_output_tokens = max_output_tokens + if top_p is not None: + request.top_p = top_p + if top_k is not None: + request.top_k = top_k + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.generate_text] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def embed_text(self, + request: Optional[Union[text_service.EmbedTextRequest, dict]] = None, + *, + model: Optional[str] = None, + text: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.EmbedTextResponse: + r"""Generates an embedding from the model given an input + message. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_embed_text(): + # Create a client + client = generativelanguage_v1beta3.TextServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.EmbedTextRequest( + model="model_value", + text="text_value", + ) + + # Make the request + response = client.embed_text(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.EmbedTextRequest, dict]): + The request object. Request to get a text embedding from + the model. + model (str): + Required. The model name to use with + the format model=models/{model}. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + text (str): + Required. The free-form input text + that the model will turn into an + embedding. + + This corresponds to the ``text`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.EmbedTextResponse: + The response to a EmbedTextRequest. + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, text]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a text_service.EmbedTextRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, text_service.EmbedTextRequest): + request = text_service.EmbedTextRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if text is not None: + request.text = text + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.embed_text] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def batch_embed_text(self, + request: Optional[Union[text_service.BatchEmbedTextRequest, dict]] = None, + *, + model: Optional[str] = None, + texts: Optional[MutableSequence[str]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.BatchEmbedTextResponse: + r"""Generates multiple embeddings from the model given + input text in a synchronous call. + + .. 
code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_batch_embed_text(): + # Create a client + client = generativelanguage_v1beta3.TextServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.BatchEmbedTextRequest( + model="model_value", + texts=['texts_value1', 'texts_value2'], + ) + + # Make the request + response = client.batch_embed_text(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.BatchEmbedTextRequest, dict]): + The request object. Batch request to get a text embedding + from the model. + model (str): + Required. The name of the ``Model`` to use for + generating the embedding. Examples: + models/embedding-gecko-001 + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + texts (MutableSequence[str]): + Required. The free-form input texts + that the model will turn into an + embedding. The current limit is 100 + texts, over which an error will be + thrown. + + This corresponds to the ``texts`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.BatchEmbedTextResponse: + The response to a EmbedTextRequest. + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, texts]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a text_service.BatchEmbedTextRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, text_service.BatchEmbedTextRequest): + request = text_service.BatchEmbedTextRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if texts is not None: + request.texts = texts + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.batch_embed_text] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
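+        # (Usage sketch; the embedding model name below is only an example:
+        #     response = client.batch_embed_text(
+        #         model="models/embedding-gecko-001",
+        #         texts=["hello", "world"],
+        #     )
+        # The response carries one ``Embedding`` per input text, in order,
+        # so ``[list(e.value) for e in response.embeddings]`` yields the raw
+        # float vectors.)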
+ return response + + def count_text_tokens(self, + request: Optional[Union[text_service.CountTextTokensRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[text_service.TextPrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.CountTextTokensResponse: + r"""Runs a model's tokenizer on a text and returns the + token count. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1beta3 + + def sample_count_text_tokens(): + # Create a client + client = generativelanguage_v1beta3.TextServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta3.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1beta3.CountTextTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.count_text_tokens(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1beta3.types.CountTextTokensRequest, dict]): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + model (str): + Required. The model's resource name. This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (google.ai.generativelanguage_v1beta3.types.TextPrompt): + Required. The free-form input text + given to the model as a prompt. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.ai.generativelanguage_v1beta3.types.CountTextTokensResponse: + A response from CountTextTokens. + + It returns the model's token_count for the prompt. + + """ + # Create or coerce a protobuf request object. + # Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt]) + if request is not None and has_flattened_params: + raise ValueError('If the `request` argument is set, then none of ' + 'the individual field arguments should be set.') + + # Minor optimization to avoid making a copy if the user passes + # in a text_service.CountTextTokensRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. 
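+        # (``request`` may also be a plain dict here; the proto-plus
+        # constructor below coerces it into a CountTextTokensRequest
+        # field by field.)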
+ if not isinstance(request, text_service.CountTextTokensRequest): + request = text_service.CountTextTokensRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.count_text_tokens] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("model", request.model), + )), + ) + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "TextServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + + + + + + + + + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +__all__ = ( + "TextServiceClient", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/__init__.py new file mode 100644 index 000000000000..71e949c7a4f5 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/__init__.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import TextServiceTransport +from .grpc import TextServiceGrpcTransport +from .grpc_asyncio import TextServiceGrpcAsyncIOTransport +from .rest import TextServiceRestTransport +from .rest import TextServiceRestInterceptor + + +# Compile a registry of transports. 
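+# The keys are the values accepted for the client's ``transport`` argument
+# (e.g. ``TextServiceClient(transport="rest")``); the client resolves that
+# string to one of the transport classes below.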
+_transport_registry = OrderedDict() # type: Dict[str, Type[TextServiceTransport]] +_transport_registry['grpc'] = TextServiceGrpcTransport +_transport_registry['grpc_asyncio'] = TextServiceGrpcAsyncIOTransport +_transport_registry['rest'] = TextServiceRestTransport + +__all__ = ( + 'TextServiceTransport', + 'TextServiceGrpcTransport', + 'TextServiceGrpcAsyncIOTransport', + 'TextServiceRestTransport', + 'TextServiceRestInterceptor', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/base.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/base.py new file mode 100644 index 000000000000..ab8ddc3e423d --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/base.py @@ -0,0 +1,190 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version + +import google.auth # type: ignore +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1beta3.types import text_service +from google.longrunning import operations_pb2 # type: ignore + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) + + +class TextServiceTransport(abc.ABC): + """Abstract transport class for TextService.""" + + AUTH_SCOPES = ( + ) + + DEFAULT_HOST: str = 'generativelanguage.googleapis.com' + def __init__( + self, *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. 
+ quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, + **scopes_kwargs, + quota_project_id=quota_project_id + ) + elif credentials is None: + credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience(api_audience if api_audience else host) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ':' not in host: + host += ':443' + self._host = host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.generate_text: gapic_v1.method.wrap_method( + self.generate_text, + default_timeout=None, + client_info=client_info, + ), + self.embed_text: gapic_v1.method.wrap_method( + self.embed_text, + default_timeout=None, + client_info=client_info, + ), + self.batch_embed_text: gapic_v1.method.wrap_method( + self.batch_embed_text, + default_timeout=None, + client_info=client_info, + ), + self.count_text_tokens: gapic_v1.method.wrap_method( + self.count_text_tokens, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! 
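+
+        Prefer using the client as a context manager
+        (``with TextServiceClient() as client: ...``), which calls this
+        method automatically when the block exits.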
+ """ + raise NotImplementedError() + + @property + def generate_text(self) -> Callable[ + [text_service.GenerateTextRequest], + Union[ + text_service.GenerateTextResponse, + Awaitable[text_service.GenerateTextResponse] + ]]: + raise NotImplementedError() + + @property + def embed_text(self) -> Callable[ + [text_service.EmbedTextRequest], + Union[ + text_service.EmbedTextResponse, + Awaitable[text_service.EmbedTextResponse] + ]]: + raise NotImplementedError() + + @property + def batch_embed_text(self) -> Callable[ + [text_service.BatchEmbedTextRequest], + Union[ + text_service.BatchEmbedTextResponse, + Awaitable[text_service.BatchEmbedTextResponse] + ]]: + raise NotImplementedError() + + @property + def count_text_tokens(self) -> Callable[ + [text_service.CountTextTokensRequest], + Union[ + text_service.CountTextTokensResponse, + Awaitable[text_service.CountTextTokensResponse] + ]]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ( + 'TextServiceTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc.py new file mode 100644 index 000000000000..d3b0615ad633 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc.py @@ -0,0 +1,350 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import grpc_helpers +from google.api_core import gapic_v1 +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.ai.generativelanguage_v1beta3.types import text_service +from google.longrunning import operations_pb2 # type: ignore +from .base import TextServiceTransport, DEFAULT_CLIENT_INFO + + +class TextServiceGrpcTransport(TextServiceTransport): + """gRPC backend transport for TextService. + + API for using Generative Language Models (GLMs) trained to + generate text. + Also known as Large Language Models (LLM)s, these generate text + given an input prompt from the user. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. 
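+
+    A sketch of selecting this transport explicitly by name (the client
+    otherwise picks a transport automatically):
+
+    .. code-block:: python
+
+        from google.ai import generativelanguage_v1beta3
+
+        client = generativelanguage_v1beta3.TextServiceClient(transport="grpc")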
+ """ + _stubs: Dict[str, Callable] + + def __init__(self, *, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[grpc.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
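+
+        Example (a minimal sketch; assumes application default credentials
+        are discoverable in the environment):
+
+        .. code-block:: python
+
+            transport = TextServiceGrpcTransport()
+            # Hand the transport to the client, e.g.
+            # client = TextServiceClient(transport=transport)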
+ """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel(cls, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
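+
+        Example (a sketch using the default host and environment
+        credentials; note that a channel passed to ``__init__`` takes
+        precedence over any credentials arguments):
+
+        .. code-block:: python
+
+            channel = TextServiceGrpcTransport.create_channel()
+            transport = TextServiceGrpcTransport(channel=channel)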
+ """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def generate_text(self) -> Callable[ + [text_service.GenerateTextRequest], + text_service.GenerateTextResponse]: + r"""Return a callable for the generate text method over gRPC. + + Generates a response from the model given an input + message. + + Returns: + Callable[[~.GenerateTextRequest], + ~.GenerateTextResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'generate_text' not in self._stubs: + self._stubs['generate_text'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.TextService/GenerateText', + request_serializer=text_service.GenerateTextRequest.serialize, + response_deserializer=text_service.GenerateTextResponse.deserialize, + ) + return self._stubs['generate_text'] + + @property + def embed_text(self) -> Callable[ + [text_service.EmbedTextRequest], + text_service.EmbedTextResponse]: + r"""Return a callable for the embed text method over gRPC. + + Generates an embedding from the model given an input + message. + + Returns: + Callable[[~.EmbedTextRequest], + ~.EmbedTextResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'embed_text' not in self._stubs: + self._stubs['embed_text'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.TextService/EmbedText', + request_serializer=text_service.EmbedTextRequest.serialize, + response_deserializer=text_service.EmbedTextResponse.deserialize, + ) + return self._stubs['embed_text'] + + @property + def batch_embed_text(self) -> Callable[ + [text_service.BatchEmbedTextRequest], + text_service.BatchEmbedTextResponse]: + r"""Return a callable for the batch embed text method over gRPC. + + Generates multiple embeddings from the model given + input text in a synchronous call. + + Returns: + Callable[[~.BatchEmbedTextRequest], + ~.BatchEmbedTextResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'batch_embed_text' not in self._stubs: + self._stubs['batch_embed_text'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.TextService/BatchEmbedText', + request_serializer=text_service.BatchEmbedTextRequest.serialize, + response_deserializer=text_service.BatchEmbedTextResponse.deserialize, + ) + return self._stubs['batch_embed_text'] + + @property + def count_text_tokens(self) -> Callable[ + [text_service.CountTextTokensRequest], + text_service.CountTextTokensResponse]: + r"""Return a callable for the count text tokens method over gRPC. 
+ + Runs a model's tokenizer on a text and returns the + token count. + + Returns: + Callable[[~.CountTextTokensRequest], + ~.CountTextTokensResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'count_text_tokens' not in self._stubs: + self._stubs['count_text_tokens'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.TextService/CountTextTokens', + request_serializer=text_service.CountTextTokensRequest.serialize, + response_deserializer=text_service.CountTextTokensResponse.deserialize, + ) + return self._stubs['count_text_tokens'] + + def close(self): + self.grpc_channel.close() + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ( + 'TextServiceGrpcTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc_asyncio.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..46ac7dc2d417 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc_asyncio.py @@ -0,0 +1,349 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers_async +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.ai.generativelanguage_v1beta3.types import text_service +from google.longrunning import operations_pb2 # type: ignore +from .base import TextServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import TextServiceGrpcTransport + + +class TextServiceGrpcAsyncIOTransport(TextServiceTransport): + """gRPC AsyncIO backend transport for TextService. + + API for using Generative Language Models (GLMs) trained to + generate text. + Also known as Large Language Models (LLM)s, these generate text + given an input prompt from the user. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. 
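+
+    A minimal end-to-end sketch (hedged: assumes default credentials in
+    the environment and that the ``models/text-bison-001`` model name is
+    available):
+
+    .. code-block:: python
+
+        import asyncio
+
+        from google.ai import generativelanguage_v1beta3
+
+        async def main():
+            client = generativelanguage_v1beta3.TextServiceAsyncClient()
+            response = await client.count_text_tokens(
+                model="models/text-bison-001",
+                prompt=generativelanguage_v1beta3.TextPrompt(text="hello"),
+            )
+            print(response.token_count)
+
+        asyncio.run(main())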
+ """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel(cls, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs + ) + + def __init__(self, *, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[aio.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. 
+ client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Ignore credentials if a channel was passed. + credentials = False + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + self._grpc_channel = type(self).create_channel( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Wrap messages. 
This must be done after self._grpc_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def generate_text(self) -> Callable[ + [text_service.GenerateTextRequest], + Awaitable[text_service.GenerateTextResponse]]: + r"""Return a callable for the generate text method over gRPC. + + Generates a response from the model given an input + message. + + Returns: + Callable[[~.GenerateTextRequest], + Awaitable[~.GenerateTextResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'generate_text' not in self._stubs: + self._stubs['generate_text'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.TextService/GenerateText', + request_serializer=text_service.GenerateTextRequest.serialize, + response_deserializer=text_service.GenerateTextResponse.deserialize, + ) + return self._stubs['generate_text'] + + @property + def embed_text(self) -> Callable[ + [text_service.EmbedTextRequest], + Awaitable[text_service.EmbedTextResponse]]: + r"""Return a callable for the embed text method over gRPC. + + Generates an embedding from the model given an input + message. + + Returns: + Callable[[~.EmbedTextRequest], + Awaitable[~.EmbedTextResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'embed_text' not in self._stubs: + self._stubs['embed_text'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.TextService/EmbedText', + request_serializer=text_service.EmbedTextRequest.serialize, + response_deserializer=text_service.EmbedTextResponse.deserialize, + ) + return self._stubs['embed_text'] + + @property + def batch_embed_text(self) -> Callable[ + [text_service.BatchEmbedTextRequest], + Awaitable[text_service.BatchEmbedTextResponse]]: + r"""Return a callable for the batch embed text method over gRPC. + + Generates multiple embeddings from the model given + input text in a synchronous call. + + Returns: + Callable[[~.BatchEmbedTextRequest], + Awaitable[~.BatchEmbedTextResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'batch_embed_text' not in self._stubs: + self._stubs['batch_embed_text'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.TextService/BatchEmbedText', + request_serializer=text_service.BatchEmbedTextRequest.serialize, + response_deserializer=text_service.BatchEmbedTextResponse.deserialize, + ) + return self._stubs['batch_embed_text'] + + @property + def count_text_tokens(self) -> Callable[ + [text_service.CountTextTokensRequest], + Awaitable[text_service.CountTextTokensResponse]]: + r"""Return a callable for the count text tokens method over gRPC. 
+ + Runs a model's tokenizer on a text and returns the + token count. + + Returns: + Callable[[~.CountTextTokensRequest], + Awaitable[~.CountTextTokensResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if 'count_text_tokens' not in self._stubs: + self._stubs['count_text_tokens'] = self.grpc_channel.unary_unary( + '/google.ai.generativelanguage.v1beta3.TextService/CountTextTokens', + request_serializer=text_service.CountTextTokensRequest.serialize, + response_deserializer=text_service.CountTextTokensResponse.deserialize, + ) + return self._stubs['count_text_tokens'] + + def close(self): + return self.grpc_channel.close() + + +__all__ = ( + 'TextServiceGrpcAsyncIOTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/rest.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/rest.py new file mode 100644 index 000000000000..cdd184d866a4 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/rest.py @@ -0,0 +1,676 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.auth.transport.requests import AuthorizedSession # type: ignore +import json # type: ignore +import grpc # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.api_core import rest_helpers +from google.api_core import rest_streaming +from google.api_core import path_template +from google.api_core import gapic_v1 + +from google.protobuf import json_format +from requests import __version__ as requests_version +import dataclasses +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object] # type: ignore + + +from google.ai.generativelanguage_v1beta3.types import text_service +from google.longrunning import operations_pb2 # type: ignore + +from .base import TextServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=requests_version, +) + + +class TextServiceRestInterceptor: + """Interceptor for TextService. 
+ + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the TextServiceRestTransport. + + .. code-block:: python + class MyCustomTextServiceInterceptor(TextServiceRestInterceptor): + def pre_batch_embed_text(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_batch_embed_text(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_count_text_tokens(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_count_text_tokens(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_embed_text(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_embed_text(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_generate_text(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_generate_text(self, response): + logging.log(f"Received response: {response}") + return response + + transport = TextServiceRestTransport(interceptor=MyCustomTextServiceInterceptor()) + client = TextServiceClient(transport=transport) + + + """ + def pre_batch_embed_text(self, request: text_service.BatchEmbedTextRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[text_service.BatchEmbedTextRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for batch_embed_text + + Override in a subclass to manipulate the request or metadata + before they are sent to the TextService server. + """ + return request, metadata + + def post_batch_embed_text(self, response: text_service.BatchEmbedTextResponse) -> text_service.BatchEmbedTextResponse: + """Post-rpc interceptor for batch_embed_text + + Override in a subclass to manipulate the response + after it is returned by the TextService server but before + it is returned to user code. + """ + return response + def pre_count_text_tokens(self, request: text_service.CountTextTokensRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[text_service.CountTextTokensRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for count_text_tokens + + Override in a subclass to manipulate the request or metadata + before they are sent to the TextService server. + """ + return request, metadata + + def post_count_text_tokens(self, response: text_service.CountTextTokensResponse) -> text_service.CountTextTokensResponse: + """Post-rpc interceptor for count_text_tokens + + Override in a subclass to manipulate the response + after it is returned by the TextService server but before + it is returned to user code. + """ + return response + def pre_embed_text(self, request: text_service.EmbedTextRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[text_service.EmbedTextRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for embed_text + + Override in a subclass to manipulate the request or metadata + before they are sent to the TextService server. 
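+
+        The default implementation is a pass-through; returning
+        ``request, metadata`` unchanged behaves exactly as if no
+        interceptor were installed.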
+ """ + return request, metadata + + def post_embed_text(self, response: text_service.EmbedTextResponse) -> text_service.EmbedTextResponse: + """Post-rpc interceptor for embed_text + + Override in a subclass to manipulate the response + after it is returned by the TextService server but before + it is returned to user code. + """ + return response + def pre_generate_text(self, request: text_service.GenerateTextRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[text_service.GenerateTextRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for generate_text + + Override in a subclass to manipulate the request or metadata + before they are sent to the TextService server. + """ + return request, metadata + + def post_generate_text(self, response: text_service.GenerateTextResponse) -> text_service.GenerateTextResponse: + """Post-rpc interceptor for generate_text + + Override in a subclass to manipulate the response + after it is returned by the TextService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class TextServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: TextServiceRestInterceptor + + +class TextServiceRestTransport(TextServiceTransport): + """REST backend transport for TextService. + + API for using Generative Language Models (GLMs) trained to + generate text. + Also known as Large Language Models (LLM)s, these generate text + given an input prompt from the user. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + + """ + + def __init__(self, *, + host: str = 'generativelanguage.googleapis.com', + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[ + ], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = 'https', + interceptor: Optional[TextServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. 
+ always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience + ) + self._session = AuthorizedSession( + self._credentials, default_host=self.DEFAULT_HOST) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or TextServiceRestInterceptor() + self._prep_wrapped_messages(client_info) + + class _BatchEmbedText(TextServiceRestStub): + def __hash__(self): + return hash("BatchEmbedText") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: text_service.BatchEmbedTextRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> text_service.BatchEmbedTextResponse: + r"""Call the batch embed text method over HTTP. + + Args: + request (~.text_service.BatchEmbedTextRequest): + The request object. Batch request to get a text embedding + from the model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.text_service.BatchEmbedTextResponse: + The response to a EmbedTextRequest. 
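+
+            (This stub transcodes the request to
+            ``POST /v1beta3/{model=models/*}:batchEmbedText`` with a JSON
+            body; see ``http_options`` below.)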
+ """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'post', + 'uri': '/v1beta3/{model=models/*}:batchEmbedText', + 'body': '*', + }, + ] + request, metadata = self._interceptor.pre_batch_embed_text(request, metadata) + pb_request = text_service.BatchEmbedTextRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + including_default_value_fields=False, + use_integers_for_enums=True + ) + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = text_service.BatchEmbedTextResponse() + pb_resp = text_service.BatchEmbedTextResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_batch_embed_text(resp) + return resp + + class _CountTextTokens(TextServiceRestStub): + def __hash__(self): + return hash("CountTextTokens") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: text_service.CountTextTokensRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> text_service.CountTextTokensResponse: + r"""Call the count text tokens method over HTTP. + + Args: + request (~.text_service.CountTextTokensRequest): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.text_service.CountTextTokensResponse: + A response from ``CountTextTokens``. + + It returns the model's ``token_count`` for the + ``prompt``. 
+
+ """
+
+ http_options: List[Dict[str, str]] = [{
+ 'method': 'post',
+ 'uri': '/v1beta3/{model=models/*}:countTextTokens',
+ 'body': '*',
+ },
+ ]
+ request, metadata = self._interceptor.pre_count_text_tokens(request, metadata)
+ pb_request = text_service.CountTextTokensRequest.pb(request)
+ transcoded_request = path_template.transcode(http_options, pb_request)
+
+ # Jsonify the request body
+
+ body = json_format.MessageToJson(
+ transcoded_request['body'],
+ including_default_value_fields=False,
+ use_integers_for_enums=True
+ )
+ uri = transcoded_request['uri']
+ method = transcoded_request['method']
+
+ # Jsonify the query params
+ query_params = json.loads(json_format.MessageToJson(
+ transcoded_request['query_params'],
+ including_default_value_fields=False,
+ use_integers_for_enums=True,
+ ))
+ query_params.update(self._get_unset_required_fields(query_params))
+
+ query_params["$alt"] = "json;enum-encoding=int"
+
+ # Send the request
+ headers = dict(metadata)
+ headers['Content-Type'] = 'application/json'
+ response = getattr(self._session, method)(
+ "{host}{uri}".format(host=self._host, uri=uri),
+ timeout=timeout,
+ headers=headers,
+ params=rest_helpers.flatten_query_params(query_params, strict=True),
+ data=body,
+ )
+
+ # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception
+ # subclass.
+ if response.status_code >= 400:
+ raise core_exceptions.from_http_response(response)
+
+ # Return the response
+ resp = text_service.CountTextTokensResponse()
+ pb_resp = text_service.CountTextTokensResponse.pb(resp)
+
+ json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True)
+ resp = self._interceptor.post_count_text_tokens(resp)
+ return resp
+
+ class _EmbedText(TextServiceRestStub):
+ def __hash__(self):
+ return hash("EmbedText")
+
+ __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {
+ }
+
+ @classmethod
+ def _get_unset_required_fields(cls, message_dict):
+ return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict}
+
+ def __call__(self,
+ request: text_service.EmbedTextRequest, *,
+ retry: OptionalRetry=gapic_v1.method.DEFAULT,
+ timeout: Optional[float]=None,
+ metadata: Sequence[Tuple[str, str]]=(),
+ ) -> text_service.EmbedTextResponse:
+ r"""Call the embed text method over HTTP.
+
+ Args:
+ request (~.text_service.EmbedTextRequest):
+ The request object. Request to get a text embedding from
+ the model.
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ ~.text_service.EmbedTextResponse:
+ The response to an EmbedTextRequest.
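+
+ Example (illustrative sketch via the public client; the
+ model name is a placeholder):
+
+ .. code-block:: python
+
+     from google.ai import generativelanguage_v1beta3
+
+     client = generativelanguage_v1beta3.TextServiceClient(transport="rest")
+     request = generativelanguage_v1beta3.EmbedTextRequest(
+         model="models/embedding-gecko-001",  # placeholder model name
+         text="Hello world",
+     )
+     response = client.embed_text(request=request)
+     print(list(response.embedding.value))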
+ """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'post', + 'uri': '/v1beta3/{model=models/*}:embedText', + 'body': '*', + }, + ] + request, metadata = self._interceptor.pre_embed_text(request, metadata) + pb_request = text_service.EmbedTextRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + including_default_value_fields=False, + use_integers_for_enums=True + ) + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = text_service.EmbedTextResponse() + pb_resp = text_service.EmbedTextResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_embed_text(resp) + return resp + + class _GenerateText(TextServiceRestStub): + def __hash__(self): + return hash("GenerateText") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + def __call__(self, + request: text_service.GenerateTextRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> text_service.GenerateTextResponse: + r"""Call the generate text method over HTTP. + + Args: + request (~.text_service.GenerateTextRequest): + The request object. Request to generate a text completion + response from the model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.text_service.GenerateTextResponse: + The response from the model, + including candidate completions. 
+ + """ + + http_options: List[Dict[str, str]] = [{ + 'method': 'post', + 'uri': '/v1beta3/{model=models/*}:generateText', + 'body': '*', + }, +{ + 'method': 'post', + 'uri': '/v1beta3/{model=tunedModels/*}:generateText', + 'body': '*', + }, + ] + request, metadata = self._interceptor.pre_generate_text(request, metadata) + pb_request = text_service.GenerateTextRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + including_default_value_fields=False, + use_integers_for_enums=True + ) + uri = transcoded_request['uri'] + method = transcoded_request['method'] + + # Jsonify the query params + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + including_default_value_fields=False, + use_integers_for_enums=True, + )) + query_params.update(self._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + + # Send the request + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(self._session, method)( + "{host}{uri}".format(host=self._host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = text_service.GenerateTextResponse() + pb_resp = text_service.GenerateTextResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + resp = self._interceptor.post_generate_text(resp) + return resp + + @property + def batch_embed_text(self) -> Callable[ + [text_service.BatchEmbedTextRequest], + text_service.BatchEmbedTextResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._BatchEmbedText(self._session, self._host, self._interceptor) # type: ignore + + @property + def count_text_tokens(self) -> Callable[ + [text_service.CountTextTokensRequest], + text_service.CountTextTokensResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CountTextTokens(self._session, self._host, self._interceptor) # type: ignore + + @property + def embed_text(self) -> Callable[ + [text_service.EmbedTextRequest], + text_service.EmbedTextResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._EmbedText(self._session, self._host, self._interceptor) # type: ignore + + @property + def generate_text(self) -> Callable[ + [text_service.GenerateTextRequest], + text_service.GenerateTextResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._GenerateText(self._session, self._host, self._interceptor) # type: ignore + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__=( + 'TextServiceRestTransport', +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/__init__.py new file mode 100644 index 000000000000..b2a054b00c36 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/__init__.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .citation import ( + CitationMetadata, + CitationSource, +) +from .discuss_service import ( + CountMessageTokensRequest, + CountMessageTokensResponse, + Example, + GenerateMessageRequest, + GenerateMessageResponse, + Message, + MessagePrompt, +) +from .model import ( + Model, +) +from .model_service import ( + CreateTunedModelMetadata, + CreateTunedModelRequest, + DeleteTunedModelRequest, + GetModelRequest, + GetTunedModelRequest, + ListModelsRequest, + ListModelsResponse, + ListTunedModelsRequest, + ListTunedModelsResponse, + UpdateTunedModelRequest, +) +from .permission import ( + Permission, +) +from .permission_service import ( + CreatePermissionRequest, + DeletePermissionRequest, + GetPermissionRequest, + ListPermissionsRequest, + ListPermissionsResponse, + TransferOwnershipRequest, + TransferOwnershipResponse, + UpdatePermissionRequest, +) +from .safety import ( + ContentFilter, + SafetyFeedback, + SafetyRating, + SafetySetting, + HarmCategory, +) +from .text_service import ( + BatchEmbedTextRequest, + BatchEmbedTextResponse, + CountTextTokensRequest, + CountTextTokensResponse, + Embedding, + EmbedTextRequest, + EmbedTextResponse, + GenerateTextRequest, + GenerateTextResponse, + TextCompletion, + TextPrompt, +) +from .tuned_model import ( + Dataset, + Hyperparameters, + TunedModel, + TunedModelSource, + TuningExample, + TuningExamples, + TuningSnapshot, + TuningTask, +) + +__all__ = ( + 'CitationMetadata', + 'CitationSource', + 'CountMessageTokensRequest', + 'CountMessageTokensResponse', + 'Example', + 'GenerateMessageRequest', + 'GenerateMessageResponse', + 'Message', + 'MessagePrompt', + 'Model', + 'CreateTunedModelMetadata', + 'CreateTunedModelRequest', + 'DeleteTunedModelRequest', + 'GetModelRequest', + 'GetTunedModelRequest', + 'ListModelsRequest', + 'ListModelsResponse', + 'ListTunedModelsRequest', + 'ListTunedModelsResponse', + 'UpdateTunedModelRequest', + 'Permission', + 'CreatePermissionRequest', + 'DeletePermissionRequest', + 'GetPermissionRequest', + 'ListPermissionsRequest', + 'ListPermissionsResponse', + 'TransferOwnershipRequest', + 'TransferOwnershipResponse', + 'UpdatePermissionRequest', + 'ContentFilter', + 'SafetyFeedback', + 'SafetyRating', + 'SafetySetting', + 
'HarmCategory',
+ 'BatchEmbedTextRequest',
+ 'BatchEmbedTextResponse',
+ 'CountTextTokensRequest',
+ 'CountTextTokensResponse',
+ 'Embedding',
+ 'EmbedTextRequest',
+ 'EmbedTextResponse',
+ 'GenerateTextRequest',
+ 'GenerateTextResponse',
+ 'TextCompletion',
+ 'TextPrompt',
+ 'Dataset',
+ 'Hyperparameters',
+ 'TunedModel',
+ 'TunedModelSource',
+ 'TuningExample',
+ 'TuningExamples',
+ 'TuningSnapshot',
+ 'TuningTask',
+)
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/citation.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/citation.py
new file mode 100644
index 000000000000..f7ea0d176c60
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/citation.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+from typing import MutableMapping, MutableSequence
+
+import proto # type: ignore
+
+
+__protobuf__ = proto.module(
+ package='google.ai.generativelanguage.v1beta3',
+ manifest={
+ 'CitationMetadata',
+ 'CitationSource',
+ },
+)
+
+
+class CitationMetadata(proto.Message):
+ r"""A collection of source attributions for a piece of content.
+
+ Attributes:
+ citation_sources (MutableSequence[google.ai.generativelanguage_v1beta3.types.CitationSource]):
+ Citations to sources for a specific response.
+ """
+
+ citation_sources: MutableSequence['CitationSource'] = proto.RepeatedField(
+ proto.MESSAGE,
+ number=1,
+ message='CitationSource',
+ )
+
+
+class CitationSource(proto.Message):
+ r"""A citation to a source for a portion of a specific response.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+ Attributes:
+ start_index (int):
+ Optional. Start of segment of the response
+ that is attributed to this source.
+
+ Index indicates the start of the segment,
+ measured in bytes.
+
+ This field is a member of `oneof`_ ``_start_index``.
+ end_index (int):
+ Optional. End of the attributed segment,
+ exclusive.
+
+ This field is a member of `oneof`_ ``_end_index``.
+ uri (str):
+ Optional. URI that is attributed as a source
+ for a portion of the text.
+
+ This field is a member of `oneof`_ ``_uri``.
+ license_ (str):
+ Optional. License for the GitHub project that
+ is attributed as a source for this segment.
+
+ License info is required for code citations.
+
+ This field is a member of `oneof`_ ``_license``.
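+
+ Example (illustrative values only; the URI and license are
+ placeholders, not API output):
+
+ .. code-block:: python
+
+     from google.ai import generativelanguage_v1beta3
+
+     source = generativelanguage_v1beta3.CitationSource(
+         start_index=0,
+         end_index=128,
+         uri="https://github.com/example/project",  # placeholder
+         license_="MIT",
+     )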
+ """ + + start_index: int = proto.Field( + proto.INT32, + number=1, + optional=True, + ) + end_index: int = proto.Field( + proto.INT32, + number=2, + optional=True, + ) + uri: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + license_: str = proto.Field( + proto.STRING, + number=4, + optional=True, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/discuss_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/discuss_service.py new file mode 100644 index 000000000000..4c731553dbbc --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/discuss_service.py @@ -0,0 +1,358 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.ai.generativelanguage_v1beta3.types import citation +from google.ai.generativelanguage_v1beta3.types import safety + + +__protobuf__ = proto.module( + package='google.ai.generativelanguage.v1beta3', + manifest={ + 'GenerateMessageRequest', + 'GenerateMessageResponse', + 'Message', + 'MessagePrompt', + 'Example', + 'CountMessageTokensRequest', + 'CountMessageTokensResponse', + }, +) + + +class GenerateMessageRequest(proto.Message): + r"""Request to generate a message response from the model. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + model (str): + Required. The name of the model to use. + + Format: ``name=models/{model}``. + prompt (google.ai.generativelanguage_v1beta3.types.MessagePrompt): + Required. The structured textual input given + to the model as a prompt. + Given a + prompt, the model will return what it predicts + is the next message in the discussion. + temperature (float): + Optional. Controls the randomness of the output. + + Values can range over ``[0.0,1.0]``, inclusive. A value + closer to ``1.0`` will produce responses that are more + varied, while a value closer to ``0.0`` will typically + result in less surprising responses from the model. + + This field is a member of `oneof`_ ``_temperature``. + candidate_count (int): + Optional. The number of generated response messages to + return. + + This value must be between ``[1, 8]``, inclusive. If unset, + this will default to ``1``. + + This field is a member of `oneof`_ ``_candidate_count``. + top_p (float): + Optional. The maximum cumulative probability of tokens to + consider when sampling. + + The model uses combined Top-k and nucleus sampling. + + Nucleus sampling considers the smallest set of tokens whose + probability sum is at least ``top_p``. + + This field is a member of `oneof`_ ``_top_p``. + top_k (int): + Optional. The maximum number of tokens to consider when + sampling. 
+ + The model uses combined Top-k and nucleus sampling. + + Top-k sampling considers the set of ``top_k`` most probable + tokens. + + This field is a member of `oneof`_ ``_top_k``. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + prompt: 'MessagePrompt' = proto.Field( + proto.MESSAGE, + number=2, + message='MessagePrompt', + ) + temperature: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + candidate_count: int = proto.Field( + proto.INT32, + number=4, + optional=True, + ) + top_p: float = proto.Field( + proto.FLOAT, + number=5, + optional=True, + ) + top_k: int = proto.Field( + proto.INT32, + number=6, + optional=True, + ) + + +class GenerateMessageResponse(proto.Message): + r"""The response from the model. + + This includes candidate messages and + conversation history in the form of chronologically-ordered + messages. + + Attributes: + candidates (MutableSequence[google.ai.generativelanguage_v1beta3.types.Message]): + Candidate response messages from the model. + messages (MutableSequence[google.ai.generativelanguage_v1beta3.types.Message]): + The conversation history used by the model. + filters (MutableSequence[google.ai.generativelanguage_v1beta3.types.ContentFilter]): + A set of content filtering metadata for the prompt and + response text. + + This indicates which ``SafetyCategory``\ (s) blocked a + candidate from this response, the lowest ``HarmProbability`` + that triggered a block, and the HarmThreshold setting for + that category. + """ + + candidates: MutableSequence['Message'] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message='Message', + ) + messages: MutableSequence['Message'] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message='Message', + ) + filters: MutableSequence[safety.ContentFilter] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message=safety.ContentFilter, + ) + + +class Message(proto.Message): + r"""The base unit of structured text. + + A ``Message`` includes an ``author`` and the ``content`` of the + ``Message``. + + The ``author`` is used to tag messages when they are fed to the + model as text. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + author (str): + Optional. The author of this Message. + + This serves as a key for tagging + the content of this Message when it is fed to + the model as text. + + The author can be any alphanumeric string. + content (str): + Required. The text content of the structured ``Message``. + citation_metadata (google.ai.generativelanguage_v1beta3.types.CitationMetadata): + Output only. Citation information for model-generated + ``content`` in this ``Message``. + + If this ``Message`` was generated as output from the model, + this field may be populated with attribution information for + any text included in the ``content``. This field is used + only on output. + + This field is a member of `oneof`_ ``_citation_metadata``. + """ + + author: str = proto.Field( + proto.STRING, + number=1, + ) + content: str = proto.Field( + proto.STRING, + number=2, + ) + citation_metadata: citation.CitationMetadata = proto.Field( + proto.MESSAGE, + number=3, + optional=True, + message=citation.CitationMetadata, + ) + + +class MessagePrompt(proto.Message): + r"""All of the structured input text passed to the model as a prompt. 
+ + A ``MessagePrompt`` contains a structured set of fields that provide + context for the conversation, examples of user input/model output + message pairs that prime the model to respond in different ways, and + the conversation history or list of messages representing the + alternating turns of the conversation between the user and the + model. + + Attributes: + context (str): + Optional. Text that should be provided to the model first to + ground the response. + + If not empty, this ``context`` will be given to the model + first before the ``examples`` and ``messages``. When using a + ``context`` be sure to provide it with every request to + maintain continuity. + + This field can be a description of your prompt to the model + to help provide context and guide the responses. Examples: + "Translate the phrase from English to French." or "Given a + statement, classify the sentiment as happy, sad or neutral." + + Anything included in this field will take precedence over + message history if the total input size exceeds the model's + ``input_token_limit`` and the input request is truncated. + examples (MutableSequence[google.ai.generativelanguage_v1beta3.types.Example]): + Optional. Examples of what the model should generate. + + This includes both user input and the response that the + model should emulate. + + These ``examples`` are treated identically to conversation + messages except that they take precedence over the history + in ``messages``: If the total input size exceeds the model's + ``input_token_limit`` the input will be truncated. Items + will be dropped from ``messages`` before ``examples``. + messages (MutableSequence[google.ai.generativelanguage_v1beta3.types.Message]): + Required. A snapshot of the recent conversation history + sorted chronologically. + + Turns alternate between two authors. + + If the total input size exceeds the model's + ``input_token_limit`` the input will be truncated: The + oldest items will be dropped from ``messages``. + """ + + context: str = proto.Field( + proto.STRING, + number=1, + ) + examples: MutableSequence['Example'] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message='Example', + ) + messages: MutableSequence['Message'] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message='Message', + ) + + +class Example(proto.Message): + r"""An input/output example used to instruct the Model. + + It demonstrates how the model should respond or format its + response. + + Attributes: + input (google.ai.generativelanguage_v1beta3.types.Message): + Required. An example of an input ``Message`` from the user. + output (google.ai.generativelanguage_v1beta3.types.Message): + Required. An example of what the model should + output given the input. + """ + + input: 'Message' = proto.Field( + proto.MESSAGE, + number=1, + message='Message', + ) + output: 'Message' = proto.Field( + proto.MESSAGE, + number=2, + message='Message', + ) + + +class CountMessageTokensRequest(proto.Message): + r"""Counts the number of tokens in the ``prompt`` sent to a model. + + Models may tokenize text differently, so each model may return a + different ``token_count``. + + Attributes: + model (str): + Required. The model's resource name. This serves as an ID + for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + prompt (google.ai.generativelanguage_v1beta3.types.MessagePrompt): + Required. The prompt, whose token count is to + be returned. 
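+
+ Example (illustrative sketch via the public client; the
+ model name is a placeholder):
+
+ .. code-block:: python
+
+     from google.ai import generativelanguage_v1beta3
+
+     client = generativelanguage_v1beta3.DiscussServiceClient()
+     request = generativelanguage_v1beta3.CountMessageTokensRequest(
+         model="models/chat-bison-001",  # placeholder model name
+         prompt=generativelanguage_v1beta3.MessagePrompt(
+             messages=[generativelanguage_v1beta3.Message(content="Hi")],
+         ),
+     )
+     response = client.count_message_tokens(request=request)
+     print(response.token_count)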
+ """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + prompt: 'MessagePrompt' = proto.Field( + proto.MESSAGE, + number=2, + message='MessagePrompt', + ) + + +class CountMessageTokensResponse(proto.Message): + r"""A response from ``CountMessageTokens``. + + It returns the model's ``token_count`` for the ``prompt``. + + Attributes: + token_count (int): + The number of tokens that the ``model`` tokenizes the + ``prompt`` into. + + Always non-negative. + """ + + token_count: int = proto.Field( + proto.INT32, + number=1, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model.py new file mode 100644 index 000000000000..f5ac72ce872b --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + + +__protobuf__ = proto.module( + package='google.ai.generativelanguage.v1beta3', + manifest={ + 'Model', + }, +) + + +class Model(proto.Message): + r"""Information about a Generative Language Model. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + name (str): + Required. The resource name of the ``Model``. + + Format: ``models/{model}`` with a ``{model}`` naming + convention of: + + - "{base_model_id}-{version}" + + Examples: + + - ``models/chat-bison-001`` + base_model_id (str): + Required. The name of the base model, pass this to the + generation request. + + Examples: + + - ``chat-bison`` + version (str): + Required. The version number of the model. + + This represents the major version + display_name (str): + The human-readable name of the model. E.g. + "Chat Bison". + The name can be up to 128 characters long and + can consist of any UTF-8 characters. + description (str): + A short description of the model. + input_token_limit (int): + Maximum number of input tokens allowed for + this model. + output_token_limit (int): + Maximum number of output tokens available for + this model. + supported_generation_methods (MutableSequence[str]): + The model's supported generation methods. + + The method names are defined as Pascal case strings, such as + ``generateMessage`` which correspond to API methods. + temperature (float): + Controls the randomness of the output. + + Values can range over ``[0.0,1.0]``, inclusive. A value + closer to ``1.0`` will produce responses that are more + varied, while a value closer to ``0.0`` will typically + result in less surprising responses from the model. This + value specifies default to be used by the backend while + making the call to the model. 
+
+ This field is a member of `oneof`_ ``_temperature``.
+ top_p (float):
+ For Nucleus sampling.
+
+ Nucleus sampling considers the smallest set of tokens whose
+ probability sum is at least ``top_p``. This value specifies
+ the default used by the backend when calling the model.
+
+ This field is a member of `oneof`_ ``_top_p``.
+ top_k (int):
+ For Top-k sampling.
+
+ Top-k sampling considers the set of ``top_k`` most probable
+ tokens. This value specifies the default used by the
+ backend when calling the model.
+
+ This field is a member of `oneof`_ ``_top_k``.
+ """
+
+ name: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ base_model_id: str = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+ version: str = proto.Field(
+ proto.STRING,
+ number=3,
+ )
+ display_name: str = proto.Field(
+ proto.STRING,
+ number=4,
+ )
+ description: str = proto.Field(
+ proto.STRING,
+ number=5,
+ )
+ input_token_limit: int = proto.Field(
+ proto.INT32,
+ number=6,
+ )
+ output_token_limit: int = proto.Field(
+ proto.INT32,
+ number=7,
+ )
+ supported_generation_methods: MutableSequence[str] = proto.RepeatedField(
+ proto.STRING,
+ number=8,
+ )
+ temperature: float = proto.Field(
+ proto.FLOAT,
+ number=9,
+ optional=True,
+ )
+ top_p: float = proto.Field(
+ proto.FLOAT,
+ number=10,
+ optional=True,
+ )
+ top_k: int = proto.Field(
+ proto.INT32,
+ number=11,
+ optional=True,
+ )
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model_service.py
new file mode 100644
index 000000000000..f2f640d9da76
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model_service.py
@@ -0,0 +1,310 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+from typing import MutableMapping, MutableSequence
+
+import proto # type: ignore
+
+from google.ai.generativelanguage_v1beta3.types import model
+from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model
+from google.protobuf import field_mask_pb2 # type: ignore
+
+
+__protobuf__ = proto.module(
+ package='google.ai.generativelanguage.v1beta3',
+ manifest={
+ 'GetModelRequest',
+ 'ListModelsRequest',
+ 'ListModelsResponse',
+ 'GetTunedModelRequest',
+ 'ListTunedModelsRequest',
+ 'ListTunedModelsResponse',
+ 'CreateTunedModelRequest',
+ 'CreateTunedModelMetadata',
+ 'UpdateTunedModelRequest',
+ 'DeleteTunedModelRequest',
+ },
+)
+
+
+class GetModelRequest(proto.Message):
+ r"""Request for getting information about a specific Model.
+
+ Attributes:
+ name (str):
+ Required. The resource name of the model.
+
+ This name should match a model name returned by the
+ ``ListModels`` method.
+
+ Format: ``models/{model}``
+ """
+
+ name: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+
+
+class ListModelsRequest(proto.Message):
+ r"""Request for listing all Models.
+
+ Attributes:
+ page_size (int):
+ The maximum number of ``Models`` to return (per page).
+
+ The service may return fewer models. If unspecified, at most
+ 50 models will be returned per page. This method returns at
+ most 1000 models per page, even if you pass a larger
+ page_size.
+ page_token (str):
+ A page token, received from a previous ``ListModels`` call.
+
+ Provide the ``page_token`` returned by one request as an
+ argument to the next request to retrieve the next page.
+
+ When paginating, all other parameters provided to
+ ``ListModels`` must match the call that provided the page
+ token.
+ """
+
+ page_size: int = proto.Field(
+ proto.INT32,
+ number=2,
+ )
+ page_token: str = proto.Field(
+ proto.STRING,
+ number=3,
+ )
+
+
+class ListModelsResponse(proto.Message):
+ r"""Response from ``ListModels`` containing a paginated list of Models.
+
+ Attributes:
+ models (MutableSequence[google.ai.generativelanguage_v1beta3.types.Model]):
+ The returned Models.
+ next_page_token (str):
+ A token, which can be sent as ``page_token`` to retrieve the
+ next page.
+
+ If this field is omitted, there are no more pages.
+ """
+
+ @property
+ def raw_page(self):
+ return self
+
+ models: MutableSequence[model.Model] = proto.RepeatedField(
+ proto.MESSAGE,
+ number=1,
+ message=model.Model,
+ )
+ next_page_token: str = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+
+
+class GetTunedModelRequest(proto.Message):
+ r"""Request for getting information about a specific TunedModel.
+
+ Attributes:
+ name (str):
+ Required. The resource name of the model.
+
+ Format: ``tunedModels/my-model-id``
+ """
+
+ name: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+
+
+class ListTunedModelsRequest(proto.Message):
+ r"""Request for listing TunedModels.
+
+ Attributes:
+ page_size (int):
+ Optional. The maximum number of ``TunedModels`` to return
+ (per page). The service may return fewer tuned models.
+
+ If unspecified, at most 10 tuned models will be returned.
+ This method returns at most 1000 models per page, even if
+ you pass a larger page_size.
+ page_token (str):
+ Optional. A page token, received from a previous
+ ``ListTunedModels`` call.
+
+ Provide the ``page_token`` returned by one request as an
+ argument to the next request to retrieve the next page.
+
+ When paginating, all other parameters provided to
+ ``ListTunedModels`` must match the call that provided the
+ page token.
+ """
+
+ page_size: int = proto.Field(
+ proto.INT32,
+ number=1,
+ )
+ page_token: str = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+
+
+class ListTunedModelsResponse(proto.Message):
+ r"""Response from ``ListTunedModels`` containing a paginated list of
+ Models.
+
+ Attributes:
+ tuned_models (MutableSequence[google.ai.generativelanguage_v1beta3.types.TunedModel]):
+ The returned Models.
+ next_page_token (str):
+ A token, which can be sent as ``page_token`` to retrieve the
+ next page.
+
+ If this field is omitted, there are no more pages.
+ """
+
+ @property
+ def raw_page(self):
+ return self
+
+ tuned_models: MutableSequence[gag_tuned_model.TunedModel] = proto.RepeatedField(
+ proto.MESSAGE,
+ number=1,
+ message=gag_tuned_model.TunedModel,
+ )
+ next_page_token: str = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+
+
+class CreateTunedModelRequest(proto.Message):
+ r"""Request to create a TunedModel.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+ Attributes:
+ tuned_model_id (str):
+ Optional. The unique id for the tuned model if specified.
+ This value should be up to 40 characters; the first
+ character must be a letter, and the last can be a letter or
+ a number. The id must match the regular expression:
+ ``[a-z]([a-z0-9-]{0,38}[a-z0-9])?``.
+
+ This field is a member of `oneof`_ ``_tuned_model_id``.
+ tuned_model (google.ai.generativelanguage_v1beta3.types.TunedModel):
+ Required. The tuned model to create.
+ """
+
+ tuned_model_id: str = proto.Field(
+ proto.STRING,
+ number=1,
+ optional=True,
+ )
+ tuned_model: gag_tuned_model.TunedModel = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ message=gag_tuned_model.TunedModel,
+ )
+
+
+class CreateTunedModelMetadata(proto.Message):
+ r"""Metadata about the state and progress of creating a tuned
+ model returned from the long-running operation.
+
+ Attributes:
+ tuned_model (str):
+ Name of the tuned model associated with the
+ tuning operation.
+ total_steps (int):
+ The total number of tuning steps.
+ completed_steps (int):
+ The number of steps completed.
+ completed_percent (float):
+ The completed percentage for the tuning
+ operation.
+ snapshots (MutableSequence[google.ai.generativelanguage_v1beta3.types.TuningSnapshot]):
+ Metrics collected during tuning.
+ """
+
+ tuned_model: str = proto.Field(
+ proto.STRING,
+ number=5,
+ )
+ total_steps: int = proto.Field(
+ proto.INT32,
+ number=1,
+ )
+ completed_steps: int = proto.Field(
+ proto.INT32,
+ number=2,
+ )
+ completed_percent: float = proto.Field(
+ proto.FLOAT,
+ number=3,
+ )
+ snapshots: MutableSequence[gag_tuned_model.TuningSnapshot] = proto.RepeatedField(
+ proto.MESSAGE,
+ number=4,
+ message=gag_tuned_model.TuningSnapshot,
+ )
+
+
+class UpdateTunedModelRequest(proto.Message):
+ r"""Request to update a TunedModel.
+
+ Attributes:
+ tuned_model (google.ai.generativelanguage_v1beta3.types.TunedModel):
+ Required. The tuned model to update.
+ update_mask (google.protobuf.field_mask_pb2.FieldMask):
+ Required. The list of fields to update.
+ """
+
+ tuned_model: gag_tuned_model.TunedModel = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message=gag_tuned_model.TunedModel,
+ )
+ update_mask: field_mask_pb2.FieldMask = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ message=field_mask_pb2.FieldMask,
+ )
+
+
+class DeleteTunedModelRequest(proto.Message):
+ r"""Request to delete a TunedModel.
+
+ Attributes:
+ name (str):
+ Required. The resource name of the model. Format:
+ ``tunedModels/my-model-id``
+ """
+
+ name: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission.py
new file mode 100644
index 000000000000..7a5b9c7c14b3
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+from typing import MutableMapping, MutableSequence
+
+import proto # type: ignore
+
+
+__protobuf__ = proto.module(
+ package='google.ai.generativelanguage.v1beta3',
+ manifest={
+ 'Permission',
+ },
+)
+
+
+class Permission(proto.Message):
+ r"""A permission resource grants a user, group, or the rest of
+ the world access to a PaLM API resource (e.g. a tuned model or
+ file).
+
+ A role is a collection of permitted operations that allows users
+ to perform specific actions on PaLM API resources. To make them
+ available to users, groups, or service accounts, you assign
+ roles. When you assign a role, you grant the permissions that
+ the role contains.
+
+ There are three concentric roles. Each role is a superset of the
+ previous role's permitted operations:
+
+ - reader can use the resource (e.g. tuned model) for inference
+ - writer has reader's permissions and additionally can edit and
+ share
+ - owner has writer's permissions and additionally can delete
+
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+ Attributes:
+ name (str):
+ Output only. The permission name. A unique name will be
+ generated on create. Example:
+ ``tunedModels/{tuned_model}/permissions/{permission}``
+ grantee_type (google.ai.generativelanguage_v1beta3.types.Permission.GranteeType):
+ Required. Immutable. The type of the grantee.
+
+ This field is a member of `oneof`_ ``_grantee_type``.
+ email_address (str):
+ Optional. Immutable. The email address of the
+ user or group to which this permission refers.
+ This field is not set when the permission's
+ grantee type is EVERYONE.
+
+ This field is a member of `oneof`_ ``_email_address``.
+ role (google.ai.generativelanguage_v1beta3.types.Permission.Role):
+ Required. The role granted by this
+ permission.
+
+ This field is a member of `oneof`_ ``_role``.
+ """
+ class GranteeType(proto.Enum):
+ r"""Defines types of the grantee of this permission.
+
+ Values:
+ GRANTEE_TYPE_UNSPECIFIED (0):
+ The default value. This value is unused.
+ USER (1):
+ Represents a user. When set, you must provide email_address
+ for the user.
+ GROUP (2):
+ Represents a group. When set, you must provide email_address
+ for the group.
+ EVERYONE (3):
+ Represents access to everyone. No extra
+ information is required.
+ """
+ GRANTEE_TYPE_UNSPECIFIED = 0
+ USER = 1
+ GROUP = 2
+ EVERYONE = 3
+
+ class Role(proto.Enum):
+ r"""Defines the role granted by this permission.
+
+ Values:
+ ROLE_UNSPECIFIED (0):
+ The default value. This value is unused.
+ OWNER (1):
+ Owner can use, update, share and delete the
+ resource.
+ WRITER (2):
+ Writer can use, update and share the
+ resource.
+ READER (3):
+ Reader can use the resource.
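+
+ Example (illustrative sketch of granting a role; the parent
+ resource and email address are placeholders):
+
+ .. code-block:: python
+
+     from google.ai import generativelanguage_v1beta3
+
+     client = generativelanguage_v1beta3.PermissionServiceClient()
+     request = generativelanguage_v1beta3.CreatePermissionRequest(
+         parent="tunedModels/my-model-id",  # placeholder resource
+         permission=generativelanguage_v1beta3.Permission(
+             grantee_type=generativelanguage_v1beta3.Permission.GranteeType.USER,
+             email_address="user@example.com",  # placeholder
+             role=generativelanguage_v1beta3.Permission.Role.READER,
+         ),
+     )
+     response = client.create_permission(request=request)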
+ """ + ROLE_UNSPECIFIED = 0 + OWNER = 1 + WRITER = 2 + READER = 3 + + name: str = proto.Field( + proto.STRING, + number=1, + ) + grantee_type: GranteeType = proto.Field( + proto.ENUM, + number=2, + optional=True, + enum=GranteeType, + ) + email_address: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + role: Role = proto.Field( + proto.ENUM, + number=4, + optional=True, + enum=Role, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission_service.py new file mode 100644 index 000000000000..cb9c76ef3167 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission_service.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.protobuf import field_mask_pb2 # type: ignore + + +__protobuf__ = proto.module( + package='google.ai.generativelanguage.v1beta3', + manifest={ + 'CreatePermissionRequest', + 'GetPermissionRequest', + 'ListPermissionsRequest', + 'ListPermissionsResponse', + 'UpdatePermissionRequest', + 'DeletePermissionRequest', + 'TransferOwnershipRequest', + 'TransferOwnershipResponse', + }, +) + + +class CreatePermissionRequest(proto.Message): + r"""Request to create a ``Permission``. + + Attributes: + parent (str): + Required. The parent resource of the ``Permission``. Format: + tunedModels/{tuned_model} + permission (google.ai.generativelanguage_v1beta3.types.Permission): + Required. The permission to create. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + permission: gag_permission.Permission = proto.Field( + proto.MESSAGE, + number=2, + message=gag_permission.Permission, + ) + + +class GetPermissionRequest(proto.Message): + r"""Request for getting information about a specific ``Permission``. + + Attributes: + name (str): + Required. The resource name of the permission. + + Format: + ``tunedModels/{tuned_model}permissions/{permission}`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class ListPermissionsRequest(proto.Message): + r"""Request for listing permissions. + + Attributes: + parent (str): + Required. The parent resource of the permissions. Format: + tunedModels/{tuned_model} + page_size (int): + Optional. The maximum number of ``Permission``\ s to return + (per page). The service may return fewer permissions. + + If unspecified, at most 10 permissions will be returned. + This method returns at most 1000 permissions per page, even + if you pass larger page_size. + page_token (str): + Optional. 
+ Optional. A page token, received from a previous
+ ``ListPermissions`` call.
+
+ Provide the ``page_token`` returned by one request as an
+ argument to the next request to retrieve the next page.
+
+ When paginating, all other parameters provided to
+ ``ListPermissions`` must match the call that provided the
+ page token.
+ """
+
+ parent: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ page_size: int = proto.Field(
+ proto.INT32,
+ number=2,
+ )
+ page_token: str = proto.Field(
+ proto.STRING,
+ number=3,
+ )
+
+
+class ListPermissionsResponse(proto.Message):
+ r"""Response from ``ListPermissions`` containing a paginated list of
+ permissions.
+
+ Attributes:
+ permissions (MutableSequence[google.ai.generativelanguage_v1beta3.types.Permission]):
+ Returned permissions.
+ next_page_token (str):
+ A token, which can be sent as ``page_token`` to retrieve the
+ next page.
+
+ If this field is omitted, there are no more pages.
+ """
+
+ @property
+ def raw_page(self):
+ return self
+
+ permissions: MutableSequence[gag_permission.Permission] = proto.RepeatedField(
+ proto.MESSAGE,
+ number=1,
+ message=gag_permission.Permission,
+ )
+ next_page_token: str = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+
+
+class UpdatePermissionRequest(proto.Message):
+ r"""Request to update the ``Permission``.
+
+ Attributes:
+ permission (google.ai.generativelanguage_v1beta3.types.Permission):
+ Required. The permission to update.
+
+ The permission's ``name`` field is used to identify the
+ permission to update.
+ update_mask (google.protobuf.field_mask_pb2.FieldMask):
+ Required. The list of fields to update. Accepted ones:
+
+ - role (``Permission.role`` field)
+ """
+
+ permission: gag_permission.Permission = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message=gag_permission.Permission,
+ )
+ update_mask: field_mask_pb2.FieldMask = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ message=field_mask_pb2.FieldMask,
+ )
+
+
+class DeletePermissionRequest(proto.Message):
+ r"""Request to delete the ``Permission``.
+
+ Attributes:
+ name (str):
+ Required. The resource name of the permission. Format:
+ ``tunedModels/{tuned_model}/permissions/{permission}``
+ """
+
+ name: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+
+
+class TransferOwnershipRequest(proto.Message):
+ r"""Request to transfer the ownership of the tuned model.
+
+ Attributes:
+ name (str):
+ Required. The resource name of the tuned model whose
+ ownership is to be transferred.
+
+ Format: ``tunedModels/my-model-id``
+ email_address (str):
+ Required. The email address of the user to
+ whom the tuned model is being transferred.
+ """
+
+ name: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ email_address: str = proto.Field(
+ proto.STRING,
+ number=2,
+ )
+
+
+class TransferOwnershipResponse(proto.Message):
+ r"""Response from ``TransferOwnership``.
+ """
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/safety.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/safety.py
new file mode 100644
index 000000000000..f33e790f3577
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/safety.py
@@ -0,0 +1,250 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+from typing import MutableMapping, MutableSequence
+
+import proto # type: ignore
+
+
+__protobuf__ = proto.module(
+ package='google.ai.generativelanguage.v1beta3',
+ manifest={
+ 'HarmCategory',
+ 'ContentFilter',
+ 'SafetyFeedback',
+ 'SafetyRating',
+ 'SafetySetting',
+ },
+)
+
+
+class HarmCategory(proto.Enum):
+ r"""The category of a rating.
+
+ These categories cover various kinds of harms that developers
+ may wish to adjust.
+
+ Values:
+ HARM_CATEGORY_UNSPECIFIED (0):
+ Category is unspecified.
+ HARM_CATEGORY_DEROGATORY (1):
+ Negative or harmful comments targeting
+ identity and/or protected attributes.
+ HARM_CATEGORY_TOXICITY (2):
+ Content that is rude, disrespectful, or
+ profane.
+ HARM_CATEGORY_VIOLENCE (3):
+ Describes scenarios depicting violence against
+ an individual or group, or general descriptions
+ of gore.
+ HARM_CATEGORY_SEXUAL (4):
+ Contains references to sexual acts or other
+ lewd content.
+ HARM_CATEGORY_MEDICAL (5):
+ Promotes unchecked medical advice.
+ HARM_CATEGORY_DANGEROUS (6):
+ Dangerous content that promotes, facilitates,
+ or encourages harmful acts.
+ """
+ HARM_CATEGORY_UNSPECIFIED = 0
+ HARM_CATEGORY_DEROGATORY = 1
+ HARM_CATEGORY_TOXICITY = 2
+ HARM_CATEGORY_VIOLENCE = 3
+ HARM_CATEGORY_SEXUAL = 4
+ HARM_CATEGORY_MEDICAL = 5
+ HARM_CATEGORY_DANGEROUS = 6
+
+
+class ContentFilter(proto.Message):
+ r"""Content filtering metadata associated with processing a
+ single request.
+ ContentFilter contains a reason and an optional supporting
+ string. The reason may be unspecified.
+
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+ Attributes:
+ reason (google.ai.generativelanguage_v1beta3.types.ContentFilter.BlockedReason):
+ The reason content was blocked during request
+ processing.
+ message (str):
+ A string that describes the filtering
+ behavior in more detail.
+
+ This field is a member of `oneof`_ ``_message``.
+ """
+ class BlockedReason(proto.Enum):
+ r"""A list of reasons why content may have been blocked.
+
+ Values:
+ BLOCKED_REASON_UNSPECIFIED (0):
+ A blocked reason was not specified.
+ SAFETY (1):
+ Content was blocked by safety settings.
+ OTHER (2):
+ Content was blocked, but the reason is
+ uncategorized.
+ """
+ BLOCKED_REASON_UNSPECIFIED = 0
+ SAFETY = 1
+ OTHER = 2
+
+ reason: BlockedReason = proto.Field(
+ proto.ENUM,
+ number=1,
+ enum=BlockedReason,
+ )
+ message: str = proto.Field(
+ proto.STRING,
+ number=2,
+ optional=True,
+ )
+
+
+class SafetyFeedback(proto.Message):
+ r"""Safety feedback for an entire request.
+
+ This field is populated if content in the input and/or response
+ is blocked due to safety settings. SafetyFeedback may not exist
+ for every HarmCategory. Each SafetyFeedback will return the
+ safety settings used by the request as well as the lowest
+ HarmProbability that should be allowed in order to return a
+ result.
+
+ Attributes:
+ rating (google.ai.generativelanguage_v1beta3.types.SafetyRating):
+ Safety rating evaluated from content.
+ setting (google.ai.generativelanguage_v1beta3.types.SafetySetting):
+ Safety settings applied to the request.
+ """
+
+ rating: 'SafetyRating' = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ message='SafetyRating',
+ )
+ setting: 'SafetySetting' = proto.Field(
+ proto.MESSAGE,
+ number=2,
+ message='SafetySetting',
+ )
+
+
+class SafetyRating(proto.Message):
+ r"""Safety rating for a piece of content.
+
+ The safety rating contains the category of harm and the harm
+ probability level in that category for a piece of content.
+ Content is classified for safety across a number of harm
+ categories and the probability of the harm classification is
+ included here.
+
+ Attributes:
+ category (google.ai.generativelanguage_v1beta3.types.HarmCategory):
+ Required. The category for this rating.
+ probability (google.ai.generativelanguage_v1beta3.types.SafetyRating.HarmProbability):
+ Required. The probability of harm for this
+ content.
+ """
+ class HarmProbability(proto.Enum):
+ r"""The probability that a piece of content is harmful.
+
+ The classification system gives the probability of the content
+ being unsafe. This does not indicate the severity of harm for a
+ piece of content.
+
+ Values:
+ HARM_PROBABILITY_UNSPECIFIED (0):
+ Probability is unspecified.
+ NEGLIGIBLE (1):
+ Content has a negligible chance of being
+ unsafe.
+ LOW (2):
+ Content has a low chance of being unsafe.
+ MEDIUM (3):
+ Content has a medium chance of being unsafe.
+ HIGH (4):
+ Content has a high chance of being unsafe.
+ """
+ HARM_PROBABILITY_UNSPECIFIED = 0
+ NEGLIGIBLE = 1
+ LOW = 2
+ MEDIUM = 3
+ HIGH = 4
+
+ category: 'HarmCategory' = proto.Field(
+ proto.ENUM,
+ number=3,
+ enum='HarmCategory',
+ )
+ probability: HarmProbability = proto.Field(
+ proto.ENUM,
+ number=4,
+ enum=HarmProbability,
+ )
+
+
+class SafetySetting(proto.Message):
+ r"""Safety setting, affecting the safety-blocking behavior.
+
+ Passing a safety setting for a category changes the allowed
+ probability that content is blocked.
+
+ Attributes:
+ category (google.ai.generativelanguage_v1beta3.types.HarmCategory):
+ Required. The category for this setting.
+ threshold (google.ai.generativelanguage_v1beta3.types.SafetySetting.HarmBlockThreshold):
+ Required. Controls the probability threshold
+ at which harm is blocked.
+ """
+ class HarmBlockThreshold(proto.Enum):
+ r"""Block at and beyond a specified harm probability.
+
+ Values:
+ HARM_BLOCK_THRESHOLD_UNSPECIFIED (0):
+ Threshold is unspecified.
+ BLOCK_LOW_AND_ABOVE (1):
+ Content with NEGLIGIBLE will be allowed.
+ BLOCK_MEDIUM_AND_ABOVE (2):
+ Content with NEGLIGIBLE and LOW will be
+ allowed.
+ BLOCK_ONLY_HIGH (3):
+ Content with NEGLIGIBLE, LOW, and MEDIUM will
+ be allowed.
+ BLOCK_NONE (4):
+ All content will be allowed.
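+
+ Example (illustrative sketch; the category/threshold pair
+ shown is arbitrary):
+
+ .. code-block:: python
+
+     from google.ai import generativelanguage_v1beta3
+
+     SafetySetting = generativelanguage_v1beta3.SafetySetting
+     setting = SafetySetting(
+         category=generativelanguage_v1beta3.HarmCategory.HARM_CATEGORY_TOXICITY,
+         threshold=SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+     )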
+ """ + HARM_BLOCK_THRESHOLD_UNSPECIFIED = 0 + BLOCK_LOW_AND_ABOVE = 1 + BLOCK_MEDIUM_AND_ABOVE = 2 + BLOCK_ONLY_HIGH = 3 + BLOCK_NONE = 4 + + category: 'HarmCategory' = proto.Field( + proto.ENUM, + number=3, + enum='HarmCategory', + ) + threshold: HarmBlockThreshold = proto.Field( + proto.ENUM, + number=4, + enum=HarmBlockThreshold, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/text_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/text_service.py new file mode 100644 index 000000000000..d347de6f0728 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/text_service.py @@ -0,0 +1,431 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.ai.generativelanguage_v1beta3.types import citation +from google.ai.generativelanguage_v1beta3.types import safety + + +__protobuf__ = proto.module( + package='google.ai.generativelanguage.v1beta3', + manifest={ + 'GenerateTextRequest', + 'GenerateTextResponse', + 'TextPrompt', + 'TextCompletion', + 'EmbedTextRequest', + 'EmbedTextResponse', + 'BatchEmbedTextRequest', + 'BatchEmbedTextResponse', + 'Embedding', + 'CountTextTokensRequest', + 'CountTextTokensResponse', + }, +) + + +class GenerateTextRequest(proto.Message): + r"""Request to generate a text completion response from the + model. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + model (str): + Required. The name of the ``Model`` or ``TunedModel`` to use + for generating the completion. Examples: + models/text-bison-001 tunedModels/sentence-translator-u3b7m + prompt (google.ai.generativelanguage_v1beta3.types.TextPrompt): + Required. The free-form input text given to + the model as a prompt. + Given a prompt, the model will generate a + TextCompletion response it predicts as the + completion of the input text. + temperature (float): + Optional. Controls the randomness of the output. Note: The + default value varies by model, see the ``Model.temperature`` + attribute of the ``Model`` returned the ``getModel`` + function. + + Values can range from [0.0,1.0], inclusive. A value closer + to 1.0 will produce responses that are more varied and + creative, while a value closer to 0.0 will typically result + in more straightforward responses from the model. + + This field is a member of `oneof`_ ``_temperature``. + candidate_count (int): + Optional. Number of generated responses to return. + + This value must be between [1, 8], inclusive. If unset, this + will default to 1. + + This field is a member of `oneof`_ ``_candidate_count``. + max_output_tokens (int): + Optional. 
The maximum number of tokens to include in a
+            candidate.
+
+            If unset, this will default to output_token_limit specified
+            in the ``Model`` specification.
+
+            This field is a member of `oneof`_ ``_max_output_tokens``.
+        top_p (float):
+            Optional. The maximum cumulative probability of tokens to
+            consider when sampling.
+
+            The model uses combined Top-k and nucleus sampling.
+
+            Tokens are sorted based on their assigned probabilities so
+            that only the most likely tokens are considered. Top-k
+            sampling directly limits the maximum number of tokens to
+            consider, while nucleus sampling limits the number of
+            tokens based on the cumulative probability.
+
+            Note: The default value varies by model, see the
+            ``Model.top_p`` attribute of the ``Model`` returned by the
+            ``getModel`` function.
+
+            This field is a member of `oneof`_ ``_top_p``.
+        top_k (int):
+            Optional. The maximum number of tokens to consider when
+            sampling.
+
+            The model uses combined Top-k and nucleus sampling.
+
+            Top-k sampling considers the set of ``top_k`` most probable
+            tokens. Defaults to 40.
+
+            Note: The default value varies by model, see the
+            ``Model.top_k`` attribute of the ``Model`` returned by the
+            ``getModel`` function.
+
+            This field is a member of `oneof`_ ``_top_k``.
+        safety_settings (MutableSequence[google.ai.generativelanguage_v1beta3.types.SafetySetting]):
+            A list of unique ``SafetySetting`` instances for blocking
+            unsafe content that will be enforced on the
+            ``GenerateTextRequest.prompt`` and
+            ``GenerateTextResponse.candidates``. There should not be
+            more than one setting for each ``SafetyCategory`` type. The
+            API will block any prompts and responses that fail to meet
+            the thresholds set by these settings. This list overrides
+            the default settings for each ``SafetyCategory`` it
+            specifies. If there is no ``SafetySetting`` for a given
+            ``SafetyCategory`` in the list, the API will use the
+            default safety setting for that category.
+        stop_sequences (MutableSequence[str]):
+            The set of character sequences (up to 5) that
+            will stop output generation. If specified, the
+            API will stop at the first appearance of a stop
+            sequence. The stop sequence will not be included
+            as part of the response.
+    """
+
+    model: str = proto.Field(
+        proto.STRING,
+        number=1,
+    )
+    prompt: 'TextPrompt' = proto.Field(
+        proto.MESSAGE,
+        number=2,
+        message='TextPrompt',
+    )
+    temperature: float = proto.Field(
+        proto.FLOAT,
+        number=3,
+        optional=True,
+    )
+    candidate_count: int = proto.Field(
+        proto.INT32,
+        number=4,
+        optional=True,
+    )
+    max_output_tokens: int = proto.Field(
+        proto.INT32,
+        number=5,
+        optional=True,
+    )
+    top_p: float = proto.Field(
+        proto.FLOAT,
+        number=6,
+        optional=True,
+    )
+    top_k: int = proto.Field(
+        proto.INT32,
+        number=7,
+        optional=True,
+    )
+    safety_settings: MutableSequence[safety.SafetySetting] = proto.RepeatedField(
+        proto.MESSAGE,
+        number=8,
+        message=safety.SafetySetting,
+    )
+    stop_sequences: MutableSequence[str] = proto.RepeatedField(
+        proto.STRING,
+        number=9,
+    )
+
+
+class GenerateTextResponse(proto.Message):
+    r"""The response from the model, including candidate completions.
+
+    Attributes:
+        candidates (MutableSequence[google.ai.generativelanguage_v1beta3.types.TextCompletion]):
+            Candidate responses from the model.
+        filters (MutableSequence[google.ai.generativelanguage_v1beta3.types.ContentFilter]):
+            A set of content filtering metadata for the prompt and
+            response text.
+
+            This indicates which ``SafetyCategory``\ (s) blocked a
+            candidate from this response, the lowest ``HarmProbability``
+            that triggered a block, and the HarmThreshold setting for
+            that category. It also indicates the smallest change to the
+            ``SafetySettings`` that would be necessary to unblock at
+            least one response.
+
+            The blocking is configured by the ``SafetySettings`` in the
+            request (or the default ``SafetySettings`` of the API).
+        safety_feedback (MutableSequence[google.ai.generativelanguage_v1beta3.types.SafetyFeedback]):
+            Returns any safety feedback related to
+            content filtering.
+    """
+
+    candidates: MutableSequence['TextCompletion'] = proto.RepeatedField(
+        proto.MESSAGE,
+        number=1,
+        message='TextCompletion',
+    )
+    filters: MutableSequence[safety.ContentFilter] = proto.RepeatedField(
+        proto.MESSAGE,
+        number=3,
+        message=safety.ContentFilter,
+    )
+    safety_feedback: MutableSequence[safety.SafetyFeedback] = proto.RepeatedField(
+        proto.MESSAGE,
+        number=4,
+        message=safety.SafetyFeedback,
+    )
+
+
+class TextPrompt(proto.Message):
+    r"""Text given to the model as a prompt.
+
+    The Model will use this TextPrompt to generate a text
+    completion.
+
+    Attributes:
+        text (str):
+            Required. The prompt text.
+    """
+
+    text: str = proto.Field(
+        proto.STRING,
+        number=1,
+    )
+
+
+class TextCompletion(proto.Message):
+    r"""Output text returned from a model.
+
+    .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+    Attributes:
+        output (str):
+            Output only. The generated text returned from
+            the model.
+        safety_ratings (MutableSequence[google.ai.generativelanguage_v1beta3.types.SafetyRating]):
+            Ratings for the safety of a response.
+
+            There is at most one rating per category.
+        citation_metadata (google.ai.generativelanguage_v1beta3.types.CitationMetadata):
+            Output only. Citation information for model-generated
+            ``output`` in this ``TextCompletion``.
+
+            This field may be populated with attribution information for
+            any text included in the ``output``.
+
+            This field is a member of `oneof`_ ``_citation_metadata``.
+    """
+
+    output: str = proto.Field(
+        proto.STRING,
+        number=1,
+    )
+    safety_ratings: MutableSequence[safety.SafetyRating] = proto.RepeatedField(
+        proto.MESSAGE,
+        number=2,
+        message=safety.SafetyRating,
+    )
+    citation_metadata: citation.CitationMetadata = proto.Field(
+        proto.MESSAGE,
+        number=3,
+        optional=True,
+        message=citation.CitationMetadata,
+    )
+
+
+class EmbedTextRequest(proto.Message):
+    r"""Request to get a text embedding from the model.
+
+    Attributes:
+        model (str):
+            Required. The model name to use with the
+            format model=models/{model}.
+        text (str):
+            Required. The free-form input text that the
+            model will turn into an embedding.
+    """
+
+    model: str = proto.Field(
+        proto.STRING,
+        number=1,
+    )
+    text: str = proto.Field(
+        proto.STRING,
+        number=2,
+    )
+
+
+class EmbedTextResponse(proto.Message):
+    r"""The response to an EmbedTextRequest.
+
+    .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+    Attributes:
+        embedding (google.ai.generativelanguage_v1beta3.types.Embedding):
+            Output only. The embedding generated from the
+            input text.
+
+            This field is a member of `oneof`_ ``_embedding``.
+    """
+
+    embedding: 'Embedding' = proto.Field(
+        proto.MESSAGE,
+        number=1,
+        optional=True,
+        message='Embedding',
+    )
+
+
+class BatchEmbedTextRequest(proto.Message):
+    r"""Batch request to get a text embedding from the model.
+
+    Attributes:
+        model (str):
+            Required. The name of the ``Model`` to use for generating
+            the embedding. Examples: models/embedding-gecko-001
+        texts (MutableSequence[str]):
+            Required. The free-form input texts that the
+            model will turn into an embedding. The current
+            limit is 100 texts; requests exceeding this
+            will return an error.
+    """
+
+    model: str = proto.Field(
+        proto.STRING,
+        number=1,
+    )
+    texts: MutableSequence[str] = proto.RepeatedField(
+        proto.STRING,
+        number=2,
+    )
+
+
+class BatchEmbedTextResponse(proto.Message):
+    r"""The response to a BatchEmbedTextRequest.
+
+    Attributes:
+        embeddings (MutableSequence[google.ai.generativelanguage_v1beta3.types.Embedding]):
+            Output only. The embeddings generated from
+            the input texts.
+    """
+
+    embeddings: MutableSequence['Embedding'] = proto.RepeatedField(
+        proto.MESSAGE,
+        number=1,
+        message='Embedding',
+    )
+
+
+class Embedding(proto.Message):
+    r"""A list of floats representing the embedding.
+
+    Attributes:
+        value (MutableSequence[float]):
+            The embedding values.
+    """
+
+    value: MutableSequence[float] = proto.RepeatedField(
+        proto.FLOAT,
+        number=1,
+    )
+
+
+class CountTextTokensRequest(proto.Message):
+    r"""Counts the number of tokens in the ``prompt`` sent to a model.
+
+    Models may tokenize text differently, so each model may return a
+    different ``token_count``.
+
+    Attributes:
+        model (str):
+            Required. The model's resource name. This serves as an ID
+            for the Model to use.
+
+            This name should match a model name returned by the
+            ``ListModels`` method.
+
+            Format: ``models/{model}``
+        prompt (google.ai.generativelanguage_v1beta3.types.TextPrompt):
+            Required. The free-form input text given to
+            the model as a prompt.
+    """
+
+    model: str = proto.Field(
+        proto.STRING,
+        number=1,
+    )
+    prompt: 'TextPrompt' = proto.Field(
+        proto.MESSAGE,
+        number=2,
+        message='TextPrompt',
+    )
+
+
+class CountTextTokensResponse(proto.Message):
+    r"""A response from ``CountTextTokens``.
+
+    It returns the model's ``token_count`` for the ``prompt``.
+
+    Attributes:
+        token_count (int):
+            The number of tokens that the ``model`` tokenizes the
+            ``prompt`` into.
+
+            Always non-negative.
+    """
+
+    token_count: int = proto.Field(
+        proto.INT32,
+        number=1,
+    )
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/tuned_model.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/tuned_model.py
new file mode 100644
index 000000000000..5fb0a44053d4
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/tuned_model.py
@@ -0,0 +1,414 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+from typing import MutableMapping, MutableSequence
+
+import proto  # type: ignore
+
+from google.protobuf import timestamp_pb2  # type: ignore
+
+
+__protobuf__ = proto.module(
+    package='google.ai.generativelanguage.v1beta3',
+    manifest={
+        'TunedModel',
+        'TunedModelSource',
+        'TuningTask',
+        'Hyperparameters',
+        'Dataset',
+        'TuningExamples',
+        'TuningExample',
+        'TuningSnapshot',
+    },
+)
+
+
+class TunedModel(proto.Message):
+    r"""A fine-tuned model created using
+    ModelService.CreateTunedModel.
+
+    This message has `oneof`_ fields (mutually exclusive fields).
+    For each oneof, at most one member field can be set at the same time.
+    Setting any member of the oneof automatically clears all other
+    members.
+
+    .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+    Attributes:
+        tuned_model_source (google.ai.generativelanguage_v1beta3.types.TunedModelSource):
+            Optional. TunedModel to use as the starting
+            point for training the new model.
+
+            This field is a member of `oneof`_ ``source_model``.
+        base_model (str):
+            Immutable. The name of the ``Model`` to tune. Example:
+            ``models/text-bison-001``
+
+            This field is a member of `oneof`_ ``source_model``.
+        name (str):
+            Output only. The tuned model name. A unique name will be
+            generated on create. Example: ``tunedModels/az2mb0bpw6i``.
+            If display_name is set on create, the id portion of the name
+            will be set by concatenating the words of the display_name
+            with hyphens and adding a random portion for uniqueness.
+            Example: display_name = "Sentence Translator", name =
+            "tunedModels/sentence-translator-u3b7m".
+        display_name (str):
+            Optional. The name to display for this model
+            in user interfaces. The display name must be up
+            to 40 characters including spaces.
+        description (str):
+            Optional. A short description of this model.
+        temperature (float):
+            Optional. Controls the randomness of the output.
+
+            Values can range over ``[0.0,1.0]``, inclusive. A value
+            closer to ``1.0`` will produce responses that are more
+            varied, while a value closer to ``0.0`` will typically
+            result in less surprising responses from the model.
+
+            If unset, this value defaults to the one used by the base
+            model at the time the tuned model is created.
+
+            This field is a member of `oneof`_ ``_temperature``.
+        top_p (float):
+            Optional. For Nucleus sampling.
+
+            Nucleus sampling considers the smallest set of tokens whose
+            probability sum is at least ``top_p``.
+
+            If unset, this value defaults to the one used by the base
+            model at the time the tuned model is created.
+
+            This field is a member of `oneof`_ ``_top_p``.
+        top_k (int):
+            Optional. For Top-k sampling.
+
+            Top-k sampling considers the set of ``top_k`` most probable
+            tokens. This value specifies the default used by the
+            backend when calling the model.
+
+            If unset, this value defaults to the one used by the base
+            model at the time the tuned model is created.
+
+            This field is a member of `oneof`_ ``_top_k``.
+        state (google.ai.generativelanguage_v1beta3.types.TunedModel.State):
+            Output only. The state of the tuned model.
+        create_time (google.protobuf.timestamp_pb2.Timestamp):
+            Output only. The timestamp when this model
+            was created.
+        update_time (google.protobuf.timestamp_pb2.Timestamp):
+            Output only. The timestamp when this model
+            was updated.
+        tuning_task (google.ai.generativelanguage_v1beta3.types.TuningTask):
+            Required. The tuning task that creates the
+            tuned model.
+ """ + class State(proto.Enum): + r"""The state of the tuned model. + + Values: + STATE_UNSPECIFIED (0): + The default value. This value is unused. + CREATING (1): + The model is being created. + ACTIVE (2): + The model is ready to be used. + FAILED (3): + The model failed to be created. + """ + STATE_UNSPECIFIED = 0 + CREATING = 1 + ACTIVE = 2 + FAILED = 3 + + tuned_model_source: 'TunedModelSource' = proto.Field( + proto.MESSAGE, + number=3, + oneof='source_model', + message='TunedModelSource', + ) + base_model: str = proto.Field( + proto.STRING, + number=4, + oneof='source_model', + ) + name: str = proto.Field( + proto.STRING, + number=1, + ) + display_name: str = proto.Field( + proto.STRING, + number=5, + ) + description: str = proto.Field( + proto.STRING, + number=6, + ) + temperature: float = proto.Field( + proto.FLOAT, + number=11, + optional=True, + ) + top_p: float = proto.Field( + proto.FLOAT, + number=12, + optional=True, + ) + top_k: int = proto.Field( + proto.INT32, + number=13, + optional=True, + ) + state: State = proto.Field( + proto.ENUM, + number=7, + enum=State, + ) + create_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=8, + message=timestamp_pb2.Timestamp, + ) + update_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=9, + message=timestamp_pb2.Timestamp, + ) + tuning_task: 'TuningTask' = proto.Field( + proto.MESSAGE, + number=10, + message='TuningTask', + ) + + +class TunedModelSource(proto.Message): + r"""Tuned model as a source for training a new model. + + Attributes: + tuned_model (str): + Immutable. The name of the ``TunedModel`` to use as the + starting point for training the new model. Example: + ``tunedModels/my-tuned-model`` + base_model (str): + Output only. The name of the base ``Model`` this + ``TunedModel`` was tuned from. Example: + ``models/text-bison-001`` + """ + + tuned_model: str = proto.Field( + proto.STRING, + number=1, + ) + base_model: str = proto.Field( + proto.STRING, + number=2, + ) + + +class TuningTask(proto.Message): + r"""Tuning tasks that create tuned models. + + Attributes: + start_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The timestamp when tuning this + model started. + complete_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The timestamp when tuning this + model completed. + snapshots (MutableSequence[google.ai.generativelanguage_v1beta3.types.TuningSnapshot]): + Output only. Metrics collected during tuning. + training_data (google.ai.generativelanguage_v1beta3.types.Dataset): + Required. Input only. Immutable. The model + training data. + hyperparameters (google.ai.generativelanguage_v1beta3.types.Hyperparameters): + Immutable. Hyperparameters controlling the + tuning process. If not provided, default values + will be used. + """ + + start_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=1, + message=timestamp_pb2.Timestamp, + ) + complete_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=2, + message=timestamp_pb2.Timestamp, + ) + snapshots: MutableSequence['TuningSnapshot'] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message='TuningSnapshot', + ) + training_data: 'Dataset' = proto.Field( + proto.MESSAGE, + number=4, + message='Dataset', + ) + hyperparameters: 'Hyperparameters' = proto.Field( + proto.MESSAGE, + number=5, + message='Hyperparameters', + ) + + +class Hyperparameters(proto.Message): + r"""Hyperparameters controlling the tuning process. + + .. 
_oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+    Attributes:
+        epoch_count (int):
+            Immutable. The number of training epochs. An
+            epoch is one pass through the training data. If
+            not set, a default of 10 will be used.
+
+            This field is a member of `oneof`_ ``_epoch_count``.
+        batch_size (int):
+            Immutable. The batch size hyperparameter for
+            tuning. If not set, a default of 16 or 64 will
+            be used based on the number of training
+            examples.
+
+            This field is a member of `oneof`_ ``_batch_size``.
+        learning_rate (float):
+            Immutable. The learning rate hyperparameter
+            for tuning. If not set, a default of 0.0002 or
+            0.002 will be calculated based on the number of
+            training examples.
+
+            This field is a member of `oneof`_ ``_learning_rate``.
+    """
+
+    epoch_count: int = proto.Field(
+        proto.INT32,
+        number=14,
+        optional=True,
+    )
+    batch_size: int = proto.Field(
+        proto.INT32,
+        number=15,
+        optional=True,
+    )
+    learning_rate: float = proto.Field(
+        proto.FLOAT,
+        number=16,
+        optional=True,
+    )
+
+
+class Dataset(proto.Message):
+    r"""Dataset for training or validation.
+
+    .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+    Attributes:
+        examples (google.ai.generativelanguage_v1beta3.types.TuningExamples):
+            Optional. Inline examples.
+
+            This field is a member of `oneof`_ ``dataset``.
+    """
+
+    examples: 'TuningExamples' = proto.Field(
+        proto.MESSAGE,
+        number=1,
+        oneof='dataset',
+        message='TuningExamples',
+    )
+
+
+class TuningExamples(proto.Message):
+    r"""A set of tuning examples. Can be training or validation
+    data.
+
+    Attributes:
+        examples (MutableSequence[google.ai.generativelanguage_v1beta3.types.TuningExample]):
+            Required. The examples. Example input can be
+            for text or discuss, but all examples in a set
+            must be of the same type.
+    """
+
+    examples: MutableSequence['TuningExample'] = proto.RepeatedField(
+        proto.MESSAGE,
+        number=1,
+        message='TuningExample',
+    )
+
+
+class TuningExample(proto.Message):
+    r"""A single example for tuning.
+
+    .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+    Attributes:
+        text_input (str):
+            Optional. Text model input.
+
+            This field is a member of `oneof`_ ``model_input``.
+        output (str):
+            Required. The expected model output.
+    """
+
+    text_input: str = proto.Field(
+        proto.STRING,
+        number=1,
+        oneof='model_input',
+    )
+    output: str = proto.Field(
+        proto.STRING,
+        number=3,
+    )
+
+
+class TuningSnapshot(proto.Message):
+    r"""Record for a single tuning step.
+
+    Attributes:
+        step (int):
+            Output only. The tuning step.
+        epoch (int):
+            Output only. The epoch this step was part of.
+        mean_loss (float):
+            Output only. The mean loss of the training
+            examples for this step.
+        compute_time (google.protobuf.timestamp_pb2.Timestamp):
+            Output only. The timestamp when this metric
+            was computed.
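+
+    For example, a snapshot with ``step=100`` and ``mean_loss=0.64``
+    (illustrative values) records a mean training loss of 0.64 measured
+    after the 100th tuning step.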
+    """
+
+    step: int = proto.Field(
+        proto.INT32,
+        number=1,
+    )
+    epoch: int = proto.Field(
+        proto.INT32,
+        number=2,
+    )
+    mean_loss: float = proto.Field(
+        proto.FLOAT,
+        number=3,
+    )
+    compute_time: timestamp_pb2.Timestamp = proto.Field(
+        proto.MESSAGE,
+        number=4,
+        message=timestamp_pb2.Timestamp,
+    )
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/mypy.ini b/owl-bot-staging/google-ai-generativelanguage/v1beta3/mypy.ini
new file mode 100644
index 000000000000..574c5aed394b
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/mypy.ini
@@ -0,0 +1,3 @@
+[mypy]
+python_version = 3.7
+namespace_packages = True
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/noxfile.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/noxfile.py
new file mode 100644
index 000000000000..66bac3a254b7
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/noxfile.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import pathlib
+import shutil
+import subprocess
+import sys
+
+
+import nox  # type: ignore
+
+ALL_PYTHON = [
+    "3.7",
+    "3.8",
+    "3.9",
+    "3.10",
+    "3.11",
+]
+
+CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()
+
+LOWER_BOUND_CONSTRAINTS_FILE = CURRENT_DIRECTORY / "constraints.txt"
+PACKAGE_NAME = subprocess.check_output([sys.executable, "setup.py", "--name"], encoding="utf-8")
+
+BLACK_VERSION = "black==22.3.0"
+BLACK_PATHS = ["docs", "google", "tests", "samples", "noxfile.py", "setup.py"]
+DEFAULT_PYTHON_VERSION = "3.11"
+
+nox.sessions = [
+    "unit",
+    "cover",
+    "mypy",
+    "check_lower_bounds",
+    # exclude update_lower_bounds from default
+    "docs",
+    "blacken",
+    "lint",
+    "lint_setup_py",
+]
+
+@nox.session(python=ALL_PYTHON)
+def unit(session):
+    """Run the unit test suite."""
+
+    session.install('coverage', 'pytest', 'pytest-cov', 'pytest-asyncio', 'asyncmock; python_version < "3.8"')
+    session.install('-e', '.')
+
+    session.run(
+        'py.test',
+        '--quiet',
+        '--cov=google/ai/generativelanguage_v1beta3/',
+        '--cov=tests/',
+        '--cov-config=.coveragerc',
+        '--cov-report=term',
+        '--cov-report=html',
+        os.path.join('tests', 'unit', ''.join(session.posargs))
+    )
+
+
+@nox.session(python=DEFAULT_PYTHON_VERSION)
+def cover(session):
+    """Run the final coverage report.
+    This outputs the coverage report aggregating coverage from the unit
+    test runs (not system test runs), and then erases coverage data.
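+
+    Run it with ``nox -s cover`` after one or more ``unit`` sessions
+    have written coverage data.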
+ """ + session.install("coverage", "pytest-cov") + session.run("coverage", "report", "--show-missing", "--fail-under=100") + + session.run("coverage", "erase") + + +@nox.session(python=ALL_PYTHON) +def mypy(session): + """Run the type checker.""" + session.install( + 'mypy', + 'types-requests', + 'types-protobuf' + ) + session.install('.') + session.run( + 'mypy', + '--explicit-package-bases', + 'google', + ) + + +@nox.session +def update_lower_bounds(session): + """Update lower bounds in constraints.txt to match setup.py""" + session.install('google-cloud-testutils') + session.install('.') + + session.run( + 'lower-bound-checker', + 'update', + '--package-name', + PACKAGE_NAME, + '--constraints-file', + str(LOWER_BOUND_CONSTRAINTS_FILE), + ) + + +@nox.session +def check_lower_bounds(session): + """Check lower bounds in setup.py are reflected in constraints file""" + session.install('google-cloud-testutils') + session.install('.') + + session.run( + 'lower-bound-checker', + 'check', + '--package-name', + PACKAGE_NAME, + '--constraints-file', + str(LOWER_BOUND_CONSTRAINTS_FILE), + ) + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def docs(session): + """Build the docs for this library.""" + + session.install("-e", ".") + session.install("sphinx==7.0.1", "alabaster", "recommonmark") + + shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) + session.run( + "sphinx-build", + "-W", # warnings as errors + "-T", # show full traceback on exception + "-N", # no colors + "-b", + "html", + "-d", + os.path.join("docs", "_build", "doctrees", ""), + os.path.join("docs", ""), + os.path.join("docs", "_build", "html", ""), + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def lint(session): + """Run linters. + + Returns a failure if the linters find linting errors or sufficiently + serious code quality issues. + """ + session.install("flake8", BLACK_VERSION) + session.run( + "black", + "--check", + *BLACK_PATHS, + ) + session.run("flake8", "google", "tests", "samples") + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def blacken(session): + """Run black. Format code to uniform standard.""" + session.install(BLACK_VERSION) + session.run( + "black", + *BLACK_PATHS, + ) + + +@nox.session(python=DEFAULT_PYTHON_VERSION) +def lint_setup_py(session): + """Verify that setup.py is valid (including RST check).""" + session.install("docutils", "pygments") + session.run("python", "setup.py", "check", "--restructuredtext", "--strict") diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_async.py new file mode 100644 index 000000000000..47654831cf34 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for CountMessageTokens
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta3_generated_DiscussService_CountMessageTokens_async]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+#   client as shown in:
+#   https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta3
+
+
+async def sample_count_message_tokens():
+    # Create a client
+    client = generativelanguage_v1beta3.DiscussServiceAsyncClient()
+
+    # Initialize request argument(s)
+    # ``prompt.messages`` is a repeated field, so append a Message rather
+    # than assigning attributes on the container.
+    prompt = generativelanguage_v1beta3.MessagePrompt()
+    prompt.messages.append(
+        generativelanguage_v1beta3.Message(content="content_value")
+    )
+
+    request = generativelanguage_v1beta3.CountMessageTokensRequest(
+        model="model_value",
+        prompt=prompt,
+    )
+
+    # Make the request
+    response = await client.count_message_tokens(request=request)
+
+    # Handle the response
+    print(response)
+
+# [END generativelanguage_v1beta3_generated_DiscussService_CountMessageTokens_async]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_sync.py
new file mode 100644
index 000000000000..707cb2f20f88
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_sync.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for CountMessageTokens
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
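+#
+# NOTE: Calling the live service also requires credentials. As one
+# illustrative option (an assumption, not part of the generated snippet),
+# an API key can be supplied through client options:
+#
+#     client = generativelanguage_v1beta3.DiscussServiceClient(
+#         client_options={"api_key": "YOUR_API_KEY"},
+#     )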
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta3_generated_DiscussService_CountMessageTokens_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+#   client as shown in:
+#   https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta3
+
+
+def sample_count_message_tokens():
+    # Create a client
+    client = generativelanguage_v1beta3.DiscussServiceClient()
+
+    # Initialize request argument(s)
+    # ``prompt.messages`` is a repeated field, so append a Message rather
+    # than assigning attributes on the container.
+    prompt = generativelanguage_v1beta3.MessagePrompt()
+    prompt.messages.append(
+        generativelanguage_v1beta3.Message(content="content_value")
+    )
+
+    request = generativelanguage_v1beta3.CountMessageTokensRequest(
+        model="model_value",
+        prompt=prompt,
+    )
+
+    # Make the request
+    response = client.count_message_tokens(request=request)
+
+    # Handle the response
+    print(response)
+
+# [END generativelanguage_v1beta3_generated_DiscussService_CountMessageTokens_sync]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_async.py
new file mode 100644
index 000000000000..591fc46d0599
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_async.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for GenerateMessage
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta3_generated_DiscussService_GenerateMessage_async]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+#   client as shown in:
+#   https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta3
+
+
+async def sample_generate_message():
+    # Create a client
+    client = generativelanguage_v1beta3.DiscussServiceAsyncClient()
+
+    # Initialize request argument(s)
+    # ``prompt.messages`` is a repeated field, so append a Message rather
+    # than assigning attributes on the container.
+    prompt = generativelanguage_v1beta3.MessagePrompt()
+    prompt.messages.append(
+        generativelanguage_v1beta3.Message(content="content_value")
+    )
+
+    request = generativelanguage_v1beta3.GenerateMessageRequest(
+        model="model_value",
+        prompt=prompt,
+    )
+
+    # Make the request
+    response = await client.generate_message(request=request)
+
+    # Handle the response
+    print(response)
+
+# [END generativelanguage_v1beta3_generated_DiscussService_GenerateMessage_async]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_sync.py
new file mode 100644
index 000000000000..773f8d2bee4c
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_sync.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for GenerateMessage
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta3_generated_DiscussService_GenerateMessage_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+#   client as shown in:
+#   https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta3
+
+
+def sample_generate_message():
+    # Create a client
+    client = generativelanguage_v1beta3.DiscussServiceClient()
+
+    # Initialize request argument(s)
+    # ``prompt.messages`` is a repeated field, so append a Message rather
+    # than assigning attributes on the container.
+    prompt = generativelanguage_v1beta3.MessagePrompt()
+    prompt.messages.append(
+        generativelanguage_v1beta3.Message(content="content_value")
+    )
+
+    request = generativelanguage_v1beta3.GenerateMessageRequest(
+        model="model_value",
+        prompt=prompt,
+    )
+
+    # Make the request
+    response = client.generate_message(request=request)
+
+    # Handle the response
+    print(response)
+
+# [END generativelanguage_v1beta3_generated_DiscussService_GenerateMessage_sync]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_async.py
new file mode 100644
index 000000000000..df994db79e4b
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_async.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for CreateTunedModel
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta3_generated_ModelService_CreateTunedModel_async]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+#   client as shown in:
+#   https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta3
+
+
+async def sample_create_tuned_model():
+    # Create a client
+    client = generativelanguage_v1beta3.ModelServiceAsyncClient()
+
+    # Initialize request argument(s)
+    # ``examples.examples`` is a repeated field, so append a TuningExample
+    # rather than assigning attributes on the container.
+    tuned_model = generativelanguage_v1beta3.TunedModel()
+    tuned_model.tuning_task.training_data.examples.examples.append(
+        generativelanguage_v1beta3.TuningExample(
+            text_input="text_input_value",
+            output="output_value",
+        )
+    )
+
+    request = generativelanguage_v1beta3.CreateTunedModelRequest(
+        tuned_model=tuned_model,
+    )
+
+    # Make the request
+    operation = client.create_tuned_model(request=request)
+
+    print("Waiting for operation to complete...")
+
+    response = (await operation).result()
+
+    # Handle the response
+    print(response)
+
+# [END generativelanguage_v1beta3_generated_ModelService_CreateTunedModel_async]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_sync.py
new file mode 100644
index 000000000000..566c36de64f2
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_sync.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for CreateTunedModel
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta3_generated_ModelService_CreateTunedModel_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+#   client as shown in:
+#   https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta3
+
+
+def sample_create_tuned_model():
+    # Create a client
+    client = generativelanguage_v1beta3.ModelServiceClient()
+
+    # Initialize request argument(s)
+    # ``examples.examples`` is a repeated field, so append a TuningExample
+    # rather than assigning attributes on the container.
+    tuned_model = generativelanguage_v1beta3.TunedModel()
+    tuned_model.tuning_task.training_data.examples.examples.append(
+        generativelanguage_v1beta3.TuningExample(
+            text_input="text_input_value",
+            output="output_value",
+        )
+    )
+
+    request = generativelanguage_v1beta3.CreateTunedModelRequest(
+        tuned_model=tuned_model,
+    )
+
+    # Make the request
+    operation = client.create_tuned_model(request=request)
+
+    print("Waiting for operation to complete...")
+
+    response = operation.result()
+
+    # Handle the response
+    print(response)
+
+# [END generativelanguage_v1beta3_generated_ModelService_CreateTunedModel_sync]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_async.py
new file mode 100644
index 000000000000..3dccf41cf4d4
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_async.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for DeleteTunedModel
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.
+
+# To install the latest published package dependency, execute the following:
+#   python3 -m pip install google-ai-generativelanguage
+
+
+# [START generativelanguage_v1beta3_generated_ModelService_DeleteTunedModel_async]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +async def sample_delete_tuned_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.DeleteTunedModelRequest( + name="name_value", + ) + + # Make the request + await client.delete_tuned_model(request=request) + + +# [END generativelanguage_v1beta3_generated_ModelService_DeleteTunedModel_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_sync.py new file mode 100644 index 000000000000..50ed9788296b --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_sync.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteTunedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_ModelService_DeleteTunedModel_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_delete_tuned_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.DeleteTunedModelRequest( + name="name_value", + ) + + # Make the request + client.delete_tuned_model(request=request) + + +# [END generativelanguage_v1beta3_generated_ModelService_DeleteTunedModel_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_async.py new file mode 100644 index 000000000000..64c9bb57c1f8 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_ModelService_GetModel_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +async def sample_get_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.GetModelRequest( + name="name_value", + ) + + # Make the request + response = await client.get_model(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_ModelService_GetModel_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_sync.py new file mode 100644 index 000000000000..030d00d31381 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_ModelService_GetModel_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_get_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.GetModelRequest( + name="name_value", + ) + + # Make the request + response = client.get_model(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_ModelService_GetModel_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_async.py new file mode 100644 index 000000000000..392a048cfedc --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetTunedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_ModelService_GetTunedModel_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +async def sample_get_tuned_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.GetTunedModelRequest( + name="name_value", + ) + + # Make the request + response = await client.get_tuned_model(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_ModelService_GetTunedModel_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_sync.py new file mode 100644 index 000000000000..cc83b550ac1e --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetTunedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_ModelService_GetTunedModel_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_get_tuned_model(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.GetTunedModelRequest( + name="name_value", + ) + + # Make the request + response = client.get_tuned_model(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_ModelService_GetTunedModel_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_async.py new file mode 100644 index 000000000000..d60fa3c94241 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListModels +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_ModelService_ListModels_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service
+# client as shown in:
+# https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta3
+
+
+async def sample_list_models():
+    # Create a client
+    client = generativelanguage_v1beta3.ModelServiceAsyncClient()
+
+    # Initialize request argument(s)
+    request = generativelanguage_v1beta3.ListModelsRequest(
+    )
+
+    # Make the request (awaiting the coroutine yields the async pager)
+    page_result = await client.list_models(request=request)
+
+    # Handle the response
+    async for response in page_result:
+        print(response)
+
+# [END generativelanguage_v1beta3_generated_ModelService_ListModels_async]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_sync.py
new file mode 100644
index 000000000000..9dbd520baa96
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_sync.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for ListModels
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.

+# To install the latest published package dependency, execute the following:
+# python3 -m pip install google-ai-generativelanguage


+# [START generativelanguage_v1beta3_generated_ModelService_ListModels_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_list_models(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.ListModelsRequest( + ) + + # Make the request + page_result = client.list_models(request=request) + + # Handle the response + for response in page_result: + print(response) + +# [END generativelanguage_v1beta3_generated_ModelService_ListModels_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_async.py new file mode 100644 index 000000000000..e0a69c6d173f --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListTunedModels +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_ModelService_ListTunedModels_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service
+# client as shown in:
+# https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta3
+
+
+async def sample_list_tuned_models():
+    # Create a client
+    client = generativelanguage_v1beta3.ModelServiceAsyncClient()
+
+    # Initialize request argument(s)
+    request = generativelanguage_v1beta3.ListTunedModelsRequest(
+    )
+
+    # Make the request (awaiting the coroutine yields the async pager)
+    page_result = await client.list_tuned_models(request=request)
+
+    # Handle the response
+    async for response in page_result:
+        print(response)
+
+# [END generativelanguage_v1beta3_generated_ModelService_ListTunedModels_async]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_sync.py
new file mode 100644
index 000000000000..bcc57481ac4d
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_sync.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for ListTunedModels
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.

+# To install the latest published package dependency, execute the following:
+# python3 -m pip install google-ai-generativelanguage


+# [START generativelanguage_v1beta3_generated_ModelService_ListTunedModels_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_list_tuned_models(): + # Create a client + client = generativelanguage_v1beta3.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.ListTunedModelsRequest( + ) + + # Make the request + page_result = client.list_tuned_models(request=request) + + # Handle the response + for response in page_result: + print(response) + +# [END generativelanguage_v1beta3_generated_ModelService_ListTunedModels_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_async.py new file mode 100644 index 000000000000..ede952974871 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateTunedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_ModelService_UpdateTunedModel_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
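+# - It updates nested tuning fields; assuming, per the v1beta3 protos, that
+#   tuning_task.training_data.examples.examples is a repeated field, the
+#   sample below appends a single TuningExample instead of assigning
+#   attributes on the repeated field itself.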
+# - It may require specifying regional endpoints when creating the service
+# client as shown in:
+# https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta3
+
+
+async def sample_update_tuned_model():
+    # Create a client
+    client = generativelanguage_v1beta3.ModelServiceAsyncClient()
+
+    # Initialize request argument(s)
+    tuned_model = generativelanguage_v1beta3.TunedModel()
+    example = generativelanguage_v1beta3.TuningExample(text_input="text_input_value", output="output_value")
+    tuned_model.tuning_task.training_data.examples.examples.append(example)
+
+    request = generativelanguage_v1beta3.UpdateTunedModelRequest(
+        tuned_model=tuned_model,
+    )
+
+    # Make the request
+    response = await client.update_tuned_model(request=request)
+
+    # Handle the response
+    print(response)
+
+# [END generativelanguage_v1beta3_generated_ModelService_UpdateTunedModel_async]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_sync.py
new file mode 100644
index 000000000000..d8ab4184499c
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_sync.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for UpdateTunedModel
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.

+# To install the latest published package dependency, execute the following:
+# python3 -m pip install google-ai-generativelanguage


+# [START generativelanguage_v1beta3_generated_ModelService_UpdateTunedModel_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service
+# client as shown in:
+# https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta3
+
+
+def sample_update_tuned_model():
+    # Create a client
+    client = generativelanguage_v1beta3.ModelServiceClient()
+
+    # Initialize request argument(s)
+    tuned_model = generativelanguage_v1beta3.TunedModel()
+    example = generativelanguage_v1beta3.TuningExample(text_input="text_input_value", output="output_value")
+    tuned_model.tuning_task.training_data.examples.examples.append(example)
+
+    request = generativelanguage_v1beta3.UpdateTunedModelRequest(
+        tuned_model=tuned_model,
+    )
+
+    # Make the request
+    response = client.update_tuned_model(request=request)
+
+    # Handle the response
+    print(response)
+
+# [END generativelanguage_v1beta3_generated_ModelService_UpdateTunedModel_sync]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_async.py
new file mode 100644
index 000000000000..fb1a00dacb55
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_async.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for CreatePermission
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.

+# To install the latest published package dependency, execute the following:
+# python3 -m pip install google-ai-generativelanguage


+# [START generativelanguage_v1beta3_generated_PermissionService_CreatePermission_async]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
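+# - It leaves request.permission unset; a real call presumably also needs
+#   the Permission resource itself (for example a role and a grantee); this
+#   is an assumption from the request shape, not a tested invocation.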
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +async def sample_create_permission(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.CreatePermissionRequest( + parent="parent_value", + ) + + # Make the request + response = await client.create_permission(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_PermissionService_CreatePermission_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_sync.py new file mode 100644 index 000000000000..5280196dd234 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreatePermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_PermissionService_CreatePermission_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_create_permission(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.CreatePermissionRequest( + parent="parent_value", + ) + + # Make the request + response = client.create_permission(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_PermissionService_CreatePermission_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_async.py new file mode 100644 index 000000000000..3d85c882f03b --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_async.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeletePermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_PermissionService_DeletePermission_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
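+# - It prints nothing on success: DeletePermission returns an empty
+#   message, so the call below either completes silently or raises.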
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +async def sample_delete_permission(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.DeletePermissionRequest( + name="name_value", + ) + + # Make the request + await client.delete_permission(request=request) + + +# [END generativelanguage_v1beta3_generated_PermissionService_DeletePermission_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_sync.py new file mode 100644 index 000000000000..f8cd0d4e0e11 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_sync.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeletePermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_PermissionService_DeletePermission_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_delete_permission(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.DeletePermissionRequest( + name="name_value", + ) + + # Make the request + client.delete_permission(request=request) + + +# [END generativelanguage_v1beta3_generated_PermissionService_DeletePermission_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_async.py new file mode 100644 index 000000000000..4a389843feb5 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetPermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_PermissionService_GetPermission_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +async def sample_get_permission(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.GetPermissionRequest( + name="name_value", + ) + + # Make the request + response = await client.get_permission(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_PermissionService_GetPermission_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_sync.py new file mode 100644 index 000000000000..1140e10fadd0 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetPermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_PermissionService_GetPermission_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_get_permission(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.GetPermissionRequest( + name="name_value", + ) + + # Make the request + response = client.get_permission(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_PermissionService_GetPermission_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_async.py new file mode 100644 index 000000000000..cf5fbb6b95f1 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListPermissions +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_PermissionService_ListPermissions_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
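+# - It awaits the paged call: on the async client, list_permissions is a
+#   coroutine that resolves to an async pager, which is then consumed
+#   with `async for`.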
+# - It may require specifying regional endpoints when creating the service
+# client as shown in:
+# https://googleapis.dev/python/google-api-core/latest/client_options.html
+from google.ai import generativelanguage_v1beta3
+
+
+async def sample_list_permissions():
+    # Create a client
+    client = generativelanguage_v1beta3.PermissionServiceAsyncClient()
+
+    # Initialize request argument(s)
+    request = generativelanguage_v1beta3.ListPermissionsRequest(
+        parent="parent_value",
+    )
+
+    # Make the request (awaiting the coroutine yields the async pager)
+    page_result = await client.list_permissions(request=request)
+
+    # Handle the response
+    async for response in page_result:
+        print(response)
+
+# [END generativelanguage_v1beta3_generated_PermissionService_ListPermissions_async]
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_sync.py
new file mode 100644
index 000000000000..74c4623e6cea
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_sync.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Generated code. DO NOT EDIT!
+#
+# Snippet for ListPermissions
+# NOTE: This snippet has been automatically generated for illustrative purposes only.
+# It may require modifications to work in your environment.

+# To install the latest published package dependency, execute the following:
+# python3 -m pip install google-ai-generativelanguage


+# [START generativelanguage_v1beta3_generated_PermissionService_ListPermissions_sync]
+# This snippet has been automatically generated and should be regarded as a
+# code template only.
+# It will require modifications to work:
+# - It may require correct/in-range values for request initialization.
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_list_permissions(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.ListPermissionsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_permissions(request=request) + + # Handle the response + for response in page_result: + print(response) + +# [END generativelanguage_v1beta3_generated_PermissionService_ListPermissions_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_async.py new file mode 100644 index 000000000000..2d4597d0175c --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for TransferOwnership +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_PermissionService_TransferOwnership_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +async def sample_transfer_ownership(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.TransferOwnershipRequest( + name="name_value", + email_address="email_address_value", + ) + + # Make the request + response = await client.transfer_ownership(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_PermissionService_TransferOwnership_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_sync.py new file mode 100644 index 000000000000..85ccdf88ccb1 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for TransferOwnership +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_PermissionService_TransferOwnership_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_transfer_ownership(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.TransferOwnershipRequest( + name="name_value", + email_address="email_address_value", + ) + + # Make the request + response = client.transfer_ownership(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_PermissionService_TransferOwnership_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_async.py new file mode 100644 index 000000000000..a0929b4b0b89 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_async.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdatePermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_PermissionService_UpdatePermission_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
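+# - It builds an empty UpdatePermissionRequest, which is unlikely to be
+#   accepted as-is; request.permission and request.update_mask presumably
+#   need to be populated first (an assumption based on the field names).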
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +async def sample_update_permission(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.UpdatePermissionRequest( + ) + + # Make the request + response = await client.update_permission(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_PermissionService_UpdatePermission_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_sync.py new file mode 100644 index 000000000000..23703b923e3b --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_sync.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdatePermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_PermissionService_UpdatePermission_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_update_permission(): + # Create a client + client = generativelanguage_v1beta3.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.UpdatePermissionRequest( + ) + + # Make the request + response = client.update_permission(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_PermissionService_UpdatePermission_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_async.py new file mode 100644 index 000000000000..9dd26686a1cd --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchEmbedText +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_TextService_BatchEmbedText_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
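+# - It embeds every entry of `texts` in one call; the response is expected
+#   to carry one embedding per input, along the lines of (field names
+#   assumed, not verified):
+#     for embedding in response.embeddings:
+#         print(embedding.value)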
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +async def sample_batch_embed_text(): + # Create a client + client = generativelanguage_v1beta3.TextServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.BatchEmbedTextRequest( + model="model_value", + texts=['texts_value1', 'texts_value2'], + ) + + # Make the request + response = await client.batch_embed_text(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_TextService_BatchEmbedText_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_sync.py new file mode 100644 index 000000000000..bf55a6889d17 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchEmbedText +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_TextService_BatchEmbedText_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_batch_embed_text(): + # Create a client + client = generativelanguage_v1beta3.TextServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.BatchEmbedTextRequest( + model="model_value", + texts=['texts_value1', 'texts_value2'], + ) + + # Make the request + response = client.batch_embed_text(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_TextService_BatchEmbedText_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_async.py new file mode 100644 index 000000000000..2e57c81f635a --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CountTextTokens +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_TextService_CountTextTokens_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
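+# - It only counts tokens; the result is expected to be exposed directly
+#   on the response, e.g. print(response.token_count) (field name assumed).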
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +async def sample_count_text_tokens(): + # Create a client + client = generativelanguage_v1beta3.TextServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta3.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1beta3.CountTextTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.count_text_tokens(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_TextService_CountTextTokens_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_sync.py new file mode 100644 index 000000000000..8925b43579ff --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CountTextTokens +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_TextService_CountTextTokens_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_count_text_tokens(): + # Create a client + client = generativelanguage_v1beta3.TextServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta3.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1beta3.CountTextTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.count_text_tokens(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_TextService_CountTextTokens_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_async.py new file mode 100644 index 000000000000..d4785b5d2ca6 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for EmbedText +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_TextService_EmbedText_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +async def sample_embed_text(): + # Create a client + client = generativelanguage_v1beta3.TextServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.EmbedTextRequest( + model="model_value", + text="text_value", + ) + + # Make the request + response = await client.embed_text(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_TextService_EmbedText_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_sync.py new file mode 100644 index 000000000000..90a4037149f1 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for EmbedText +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_TextService_EmbedText_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_embed_text(): + # Create a client + client = generativelanguage_v1beta3.TextServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1beta3.EmbedTextRequest( + model="model_value", + text="text_value", + ) + + # Make the request + response = client.embed_text(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_TextService_EmbedText_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_async.py new file mode 100644 index 000000000000..9ded1d48937a --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GenerateText +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_TextService_GenerateText_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +async def sample_generate_text(): + # Create a client + client = generativelanguage_v1beta3.TextServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta3.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1beta3.GenerateTextRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.generate_text(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_TextService_GenerateText_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_sync.py new file mode 100644 index 000000000000..a6c7dfe835e2 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GenerateText +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1beta3_generated_TextService_GenerateText_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. 
+# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1beta3 + + +def sample_generate_text(): + # Create a client + client = generativelanguage_v1beta3.TextServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1beta3.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1beta3.GenerateTextRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.generate_text(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1beta3_generated_TextService_GenerateText_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta3.json b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta3.json new file mode 100644 index 000000000000..91de9e353f90 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta3.json @@ -0,0 +1,3222 @@ +{ + "clientLibrary": { + "apis": [ + { + "id": "google.ai.generativelanguage.v1beta3", + "version": "v1beta3" + } + ], + "language": "PYTHON", + "name": "google-ai-generativelanguage", + "version": "0.1.0" + }, + "snippets": [ + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.DiscussServiceAsyncClient", + "shortName": "DiscussServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.DiscussServiceAsyncClient.count_message_tokens", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.DiscussService.CountMessageTokens", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.DiscussService", + "shortName": "DiscussService" + }, + "shortName": "CountMessageTokens" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.CountMessageTokensRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta3.types.MessagePrompt" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.CountMessageTokensResponse", + "shortName": "count_message_tokens" + }, + "description": "Sample for CountMessageTokens", + "file": "generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_DiscussService_CountMessageTokens_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": 
"google.ai.generativelanguage_v1beta3.DiscussServiceClient", + "shortName": "DiscussServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.DiscussServiceClient.count_message_tokens", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.DiscussService.CountMessageTokens", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.DiscussService", + "shortName": "DiscussService" + }, + "shortName": "CountMessageTokens" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.CountMessageTokensRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta3.types.MessagePrompt" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.CountMessageTokensResponse", + "shortName": "count_message_tokens" + }, + "description": "Sample for CountMessageTokens", + "file": "generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_DiscussService_CountMessageTokens_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.DiscussServiceAsyncClient", + "shortName": "DiscussServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.DiscussServiceAsyncClient.generate_message", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.DiscussService.GenerateMessage", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.DiscussService", + "shortName": "DiscussService" + }, + "shortName": "GenerateMessage" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.GenerateMessageRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta3.types.MessagePrompt" + }, + { + "name": "temperature", + "type": "float" + }, + { + "name": "candidate_count", + "type": "int" + }, + { + "name": "top_p", + "type": "float" + }, + { + "name": "top_k", + "type": "int" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.GenerateMessageResponse", + "shortName": "generate_message" + }, + "description": "Sample for GenerateMessage", + "file": "generativelanguage_v1beta3_generated_discuss_service_generate_message_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_DiscussService_GenerateMessage_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + 
"end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_discuss_service_generate_message_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.DiscussServiceClient", + "shortName": "DiscussServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.DiscussServiceClient.generate_message", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.DiscussService.GenerateMessage", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.DiscussService", + "shortName": "DiscussService" + }, + "shortName": "GenerateMessage" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.GenerateMessageRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta3.types.MessagePrompt" + }, + { + "name": "temperature", + "type": "float" + }, + { + "name": "candidate_count", + "type": "int" + }, + { + "name": "top_p", + "type": "float" + }, + { + "name": "top_k", + "type": "int" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.GenerateMessageResponse", + "shortName": "generate_message" + }, + "description": "Sample for GenerateMessage", + "file": "generativelanguage_v1beta3_generated_discuss_service_generate_message_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_DiscussService_GenerateMessage_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_discuss_service_generate_message_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient.create_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.CreateTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "CreateTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.CreateTunedModelRequest" + }, + { + "name": "tuned_model", + "type": "google.ai.generativelanguage_v1beta3.types.TunedModel" + }, + { + "name": "tuned_model_id", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": 
"google.api_core.operation_async.AsyncOperation", + "shortName": "create_tuned_model" + }, + "description": "Sample for CreateTunedModel", + "file": "generativelanguage_v1beta3_generated_model_service_create_tuned_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_CreateTunedModel_async", + "segments": [ + { + "end": 59, + "start": 27, + "type": "FULL" + }, + { + "end": 59, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 56, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 60, + "start": 57, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_create_tuned_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient.create_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.CreateTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "CreateTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.CreateTunedModelRequest" + }, + { + "name": "tuned_model", + "type": "google.ai.generativelanguage_v1beta3.types.TunedModel" + }, + { + "name": "tuned_model_id", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.api_core.operation.Operation", + "shortName": "create_tuned_model" + }, + "description": "Sample for CreateTunedModel", + "file": "generativelanguage_v1beta3_generated_model_service_create_tuned_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_CreateTunedModel_sync", + "segments": [ + { + "end": 59, + "start": 27, + "type": "FULL" + }, + { + "end": 59, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 56, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 60, + "start": 57, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_create_tuned_model_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient.delete_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.DeleteTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "DeleteTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.DeleteTunedModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": 
"timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "shortName": "delete_tuned_model" + }, + "description": "Sample for DeleteTunedModel", + "file": "generativelanguage_v1beta3_generated_model_service_delete_tuned_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_DeleteTunedModel_async", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_delete_tuned_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient.delete_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.DeleteTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "DeleteTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.DeleteTunedModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "shortName": "delete_tuned_model" + }, + "description": "Sample for DeleteTunedModel", + "file": "generativelanguage_v1beta3_generated_model_service_delete_tuned_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_DeleteTunedModel_sync", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_delete_tuned_model_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient.get_model", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.GetModel", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "GetModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.GetModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.Model", + "shortName": "get_model" + 
}, + "description": "Sample for GetModel", + "file": "generativelanguage_v1beta3_generated_model_service_get_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_GetModel_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_get_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient.get_model", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.GetModel", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "GetModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.GetModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.Model", + "shortName": "get_model" + }, + "description": "Sample for GetModel", + "file": "generativelanguage_v1beta3_generated_model_service_get_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_GetModel_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_get_model_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient.get_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.GetTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "GetTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.GetTunedModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.TunedModel", + "shortName": "get_tuned_model" + }, + "description": "Sample for GetTunedModel", + "file": 
"generativelanguage_v1beta3_generated_model_service_get_tuned_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_GetTunedModel_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_get_tuned_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient.get_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.GetTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "GetTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.GetTunedModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.TunedModel", + "shortName": "get_tuned_model" + }, + "description": "Sample for GetTunedModel", + "file": "generativelanguage_v1beta3_generated_model_service_get_tuned_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_GetTunedModel_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_get_tuned_model_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient.list_models", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.ListModels", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "ListModels" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.ListModelsRequest" + }, + { + "name": "page_size", + "type": "int" + }, + { + "name": "page_token", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.services.model_service.pagers.ListModelsAsyncPager", + "shortName": "list_models" + }, + 
"description": "Sample for ListModels", + "file": "generativelanguage_v1beta3_generated_model_service_list_models_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_ListModels_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_list_models_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient.list_models", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.ListModels", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "ListModels" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.ListModelsRequest" + }, + { + "name": "page_size", + "type": "int" + }, + { + "name": "page_token", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.services.model_service.pagers.ListModelsPager", + "shortName": "list_models" + }, + "description": "Sample for ListModels", + "file": "generativelanguage_v1beta3_generated_model_service_list_models_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_ListModels_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_list_models_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient.list_tuned_models", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.ListTunedModels", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "ListTunedModels" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.ListTunedModelsRequest" + }, + { + "name": "page_size", + "type": "int" + }, + { + "name": "page_token", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": 
"google.ai.generativelanguage_v1beta3.services.model_service.pagers.ListTunedModelsAsyncPager", + "shortName": "list_tuned_models" + }, + "description": "Sample for ListTunedModels", + "file": "generativelanguage_v1beta3_generated_model_service_list_tuned_models_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_ListTunedModels_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_list_tuned_models_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient.list_tuned_models", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.ListTunedModels", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "ListTunedModels" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.ListTunedModelsRequest" + }, + { + "name": "page_size", + "type": "int" + }, + { + "name": "page_token", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.services.model_service.pagers.ListTunedModelsPager", + "shortName": "list_tuned_models" + }, + "description": "Sample for ListTunedModels", + "file": "generativelanguage_v1beta3_generated_model_service_list_tuned_models_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_ListTunedModels_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_list_tuned_models_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceAsyncClient.update_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.UpdateTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "UpdateTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.UpdateTunedModelRequest" + }, + { + "name": "tuned_model", + "type": "google.ai.generativelanguage_v1beta3.types.TunedModel" + }, + 
{ + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.TunedModel", + "shortName": "update_tuned_model" + }, + "description": "Sample for UpdateTunedModel", + "file": "generativelanguage_v1beta3_generated_model_service_update_tuned_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_UpdateTunedModel_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_update_tuned_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.ModelServiceClient.update_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService.UpdateTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.ModelService", + "shortName": "ModelService" + }, + "shortName": "UpdateTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.UpdateTunedModelRequest" + }, + { + "name": "tuned_model", + "type": "google.ai.generativelanguage_v1beta3.types.TunedModel" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.TunedModel", + "shortName": "update_tuned_model" + }, + "description": "Sample for UpdateTunedModel", + "file": "generativelanguage_v1beta3_generated_model_service_update_tuned_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_ModelService_UpdateTunedModel_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_model_service_update_tuned_model_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceAsyncClient", + "shortName": "PermissionServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceAsyncClient.create_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService.CreatePermission", + "service": { + "fullName": 
"google.ai.generativelanguage.v1beta3.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "CreatePermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.CreatePermissionRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "permission", + "type": "google.ai.generativelanguage_v1beta3.types.Permission" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.Permission", + "shortName": "create_permission" + }, + "description": "Sample for CreatePermission", + "file": "generativelanguage_v1beta3_generated_permission_service_create_permission_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_PermissionService_CreatePermission_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_permission_service_create_permission_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceClient", + "shortName": "PermissionServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceClient.create_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService.CreatePermission", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "CreatePermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.CreatePermissionRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "permission", + "type": "google.ai.generativelanguage_v1beta3.types.Permission" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.Permission", + "shortName": "create_permission" + }, + "description": "Sample for CreatePermission", + "file": "generativelanguage_v1beta3_generated_permission_service_create_permission_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_PermissionService_CreatePermission_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_permission_service_create_permission_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": 
"google.ai.generativelanguage_v1beta3.PermissionServiceAsyncClient", + "shortName": "PermissionServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceAsyncClient.delete_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService.DeletePermission", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "DeletePermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.DeletePermissionRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "shortName": "delete_permission" + }, + "description": "Sample for DeletePermission", + "file": "generativelanguage_v1beta3_generated_permission_service_delete_permission_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_PermissionService_DeletePermission_async", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_permission_service_delete_permission_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceClient", + "shortName": "PermissionServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceClient.delete_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService.DeletePermission", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "DeletePermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.DeletePermissionRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "shortName": "delete_permission" + }, + "description": "Sample for DeletePermission", + "file": "generativelanguage_v1beta3_generated_permission_service_delete_permission_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_PermissionService_DeletePermission_sync", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_permission_service_delete_permission_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceAsyncClient", + "shortName": 
"PermissionServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceAsyncClient.get_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService.GetPermission", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "GetPermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.GetPermissionRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.Permission", + "shortName": "get_permission" + }, + "description": "Sample for GetPermission", + "file": "generativelanguage_v1beta3_generated_permission_service_get_permission_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_PermissionService_GetPermission_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_permission_service_get_permission_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceClient", + "shortName": "PermissionServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceClient.get_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService.GetPermission", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "GetPermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.GetPermissionRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.Permission", + "shortName": "get_permission" + }, + "description": "Sample for GetPermission", + "file": "generativelanguage_v1beta3_generated_permission_service_get_permission_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_PermissionService_GetPermission_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_permission_service_get_permission_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": 
"google.ai.generativelanguage_v1beta3.PermissionServiceAsyncClient", + "shortName": "PermissionServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceAsyncClient.list_permissions", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService.ListPermissions", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "ListPermissions" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.ListPermissionsRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.services.permission_service.pagers.ListPermissionsAsyncPager", + "shortName": "list_permissions" + }, + "description": "Sample for ListPermissions", + "file": "generativelanguage_v1beta3_generated_permission_service_list_permissions_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_PermissionService_ListPermissions_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_permission_service_list_permissions_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceClient", + "shortName": "PermissionServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceClient.list_permissions", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService.ListPermissions", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "ListPermissions" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.ListPermissionsRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.services.permission_service.pagers.ListPermissionsPager", + "shortName": "list_permissions" + }, + "description": "Sample for ListPermissions", + "file": "generativelanguage_v1beta3_generated_permission_service_list_permissions_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_PermissionService_ListPermissions_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + 
"title": "generativelanguage_v1beta3_generated_permission_service_list_permissions_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceAsyncClient", + "shortName": "PermissionServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceAsyncClient.transfer_ownership", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService.TransferOwnership", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "TransferOwnership" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.TransferOwnershipRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.TransferOwnershipResponse", + "shortName": "transfer_ownership" + }, + "description": "Sample for TransferOwnership", + "file": "generativelanguage_v1beta3_generated_permission_service_transfer_ownership_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_PermissionService_TransferOwnership_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_permission_service_transfer_ownership_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceClient", + "shortName": "PermissionServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceClient.transfer_ownership", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService.TransferOwnership", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "TransferOwnership" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.TransferOwnershipRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.TransferOwnershipResponse", + "shortName": "transfer_ownership" + }, + "description": "Sample for TransferOwnership", + "file": "generativelanguage_v1beta3_generated_permission_service_transfer_ownership_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_PermissionService_TransferOwnership_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + 
}, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_permission_service_transfer_ownership_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceAsyncClient", + "shortName": "PermissionServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceAsyncClient.update_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService.UpdatePermission", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "UpdatePermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.UpdatePermissionRequest" + }, + { + "name": "permission", + "type": "google.ai.generativelanguage_v1beta3.types.Permission" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.Permission", + "shortName": "update_permission" + }, + "description": "Sample for UpdatePermission", + "file": "generativelanguage_v1beta3_generated_permission_service_update_permission_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_PermissionService_UpdatePermission_async", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_permission_service_update_permission_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceClient", + "shortName": "PermissionServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.PermissionServiceClient.update_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService.UpdatePermission", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "UpdatePermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.UpdatePermissionRequest" + }, + { + "name": "permission", + "type": "google.ai.generativelanguage_v1beta3.types.Permission" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.Permission", + "shortName": "update_permission" + }, + "description": "Sample for UpdatePermission", + "file": "generativelanguage_v1beta3_generated_permission_service_update_permission_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": 
"generativelanguage_v1beta3_generated_PermissionService_UpdatePermission_sync", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_permission_service_update_permission_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceAsyncClient", + "shortName": "TextServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceAsyncClient.batch_embed_text", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService.BatchEmbedText", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService", + "shortName": "TextService" + }, + "shortName": "BatchEmbedText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.BatchEmbedTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "texts", + "type": "MutableSequence[str]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.BatchEmbedTextResponse", + "shortName": "batch_embed_text" + }, + "description": "Sample for BatchEmbedText", + "file": "generativelanguage_v1beta3_generated_text_service_batch_embed_text_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_TextService_BatchEmbedText_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_text_service_batch_embed_text_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceClient", + "shortName": "TextServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceClient.batch_embed_text", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService.BatchEmbedText", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService", + "shortName": "TextService" + }, + "shortName": "BatchEmbedText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.BatchEmbedTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "texts", + "type": "MutableSequence[str]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.BatchEmbedTextResponse", + "shortName": "batch_embed_text" + }, + "description": "Sample for BatchEmbedText", + "file": 
"generativelanguage_v1beta3_generated_text_service_batch_embed_text_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_TextService_BatchEmbedText_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_text_service_batch_embed_text_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceAsyncClient", + "shortName": "TextServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceAsyncClient.count_text_tokens", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService.CountTextTokens", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService", + "shortName": "TextService" + }, + "shortName": "CountTextTokens" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.CountTextTokensRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta3.types.TextPrompt" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.CountTextTokensResponse", + "shortName": "count_text_tokens" + }, + "description": "Sample for CountTextTokens", + "file": "generativelanguage_v1beta3_generated_text_service_count_text_tokens_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_TextService_CountTextTokens_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_text_service_count_text_tokens_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceClient", + "shortName": "TextServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceClient.count_text_tokens", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService.CountTextTokens", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService", + "shortName": "TextService" + }, + "shortName": "CountTextTokens" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.CountTextTokensRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta3.types.TextPrompt" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": 
"Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.CountTextTokensResponse", + "shortName": "count_text_tokens" + }, + "description": "Sample for CountTextTokens", + "file": "generativelanguage_v1beta3_generated_text_service_count_text_tokens_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_TextService_CountTextTokens_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_text_service_count_text_tokens_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceAsyncClient", + "shortName": "TextServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceAsyncClient.embed_text", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService.EmbedText", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService", + "shortName": "TextService" + }, + "shortName": "EmbedText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.EmbedTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "text", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.EmbedTextResponse", + "shortName": "embed_text" + }, + "description": "Sample for EmbedText", + "file": "generativelanguage_v1beta3_generated_text_service_embed_text_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_TextService_EmbedText_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_text_service_embed_text_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceClient", + "shortName": "TextServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceClient.embed_text", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService.EmbedText", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService", + "shortName": "TextService" + }, + "shortName": "EmbedText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.EmbedTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "text", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": 
"metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.EmbedTextResponse", + "shortName": "embed_text" + }, + "description": "Sample for EmbedText", + "file": "generativelanguage_v1beta3_generated_text_service_embed_text_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_TextService_EmbedText_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_text_service_embed_text_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceAsyncClient", + "shortName": "TextServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceAsyncClient.generate_text", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService.GenerateText", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService", + "shortName": "TextService" + }, + "shortName": "GenerateText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.GenerateTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta3.types.TextPrompt" + }, + { + "name": "temperature", + "type": "float" + }, + { + "name": "candidate_count", + "type": "int" + }, + { + "name": "max_output_tokens", + "type": "int" + }, + { + "name": "top_p", + "type": "float" + }, + { + "name": "top_k", + "type": "int" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.GenerateTextResponse", + "shortName": "generate_text" + }, + "description": "Sample for GenerateText", + "file": "generativelanguage_v1beta3_generated_text_service_generate_text_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_TextService_GenerateText_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_text_service_generate_text_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceClient", + "shortName": "TextServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1beta3.TextServiceClient.generate_text", + "method": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService.GenerateText", + "service": { + "fullName": "google.ai.generativelanguage.v1beta3.TextService", + "shortName": "TextService" + }, + "shortName": "GenerateText" + }, + 
"parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1beta3.types.GenerateTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1beta3.types.TextPrompt" + }, + { + "name": "temperature", + "type": "float" + }, + { + "name": "candidate_count", + "type": "int" + }, + { + "name": "max_output_tokens", + "type": "int" + }, + { + "name": "top_p", + "type": "float" + }, + { + "name": "top_k", + "type": "int" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, str]" + } + ], + "resultType": "google.ai.generativelanguage_v1beta3.types.GenerateTextResponse", + "shortName": "generate_text" + }, + "description": "Sample for GenerateText", + "file": "generativelanguage_v1beta3_generated_text_service_generate_text_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1beta3_generated_TextService_GenerateText_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1beta3_generated_text_service_generate_text_sync.py" + } + ] +} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/scripts/fixup_generativelanguage_v1beta3_keywords.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/scripts/fixup_generativelanguage_v1beta3_keywords.py new file mode 100644 index 000000000000..34e59639f6ae --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/scripts/fixup_generativelanguage_v1beta3_keywords.py @@ -0,0 +1,194 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+import argparse
+import os
+import libcst as cst
+import pathlib
+import sys
+from typing import (Any, Callable, Dict, List, Sequence, Tuple)
+
+
+def partition(
+    predicate: Callable[[Any], bool],
+    iterator: Sequence[Any]
+) -> Tuple[List[Any], List[Any]]:
+    """A stable, out-of-place partition."""
+    results = ([], [])
+
+    for i in iterator:
+        results[int(predicate(i))].append(i)
+
+    # Returns trueList, falseList
+    return results[1], results[0]
+
+
+class generativelanguageCallTransformer(cst.CSTTransformer):
+    CTRL_PARAMS: Tuple[str, ...] = ('retry', 'timeout', 'metadata')
+    METHOD_TO_PARAMS: Dict[str, Tuple[str, ...]] = {
+        'batch_embed_text': ('model', 'texts', ),
+        'count_message_tokens': ('model', 'prompt', ),
+        'count_text_tokens': ('model', 'prompt', ),
+        'create_permission': ('parent', 'permission', ),
+        'create_tuned_model': ('tuned_model', 'tuned_model_id', ),
+        'delete_permission': ('name', ),
+        'delete_tuned_model': ('name', ),
+        'embed_text': ('model', 'text', ),
+        'generate_message': ('model', 'prompt', 'temperature', 'candidate_count', 'top_p', 'top_k', ),
+        'generate_text': ('model', 'prompt', 'temperature', 'candidate_count', 'max_output_tokens', 'top_p', 'top_k', 'safety_settings', 'stop_sequences', ),
+        'get_model': ('name', ),
+        'get_permission': ('name', ),
+        'get_tuned_model': ('name', ),
+        'list_models': ('page_size', 'page_token', ),
+        'list_permissions': ('parent', 'page_size', 'page_token', ),
+        'list_tuned_models': ('page_size', 'page_token', ),
+        'transfer_ownership': ('name', 'email_address', ),
+        'update_permission': ('permission', 'update_mask', ),
+        'update_tuned_model': ('tuned_model', 'update_mask', ),
+    }
+
+    def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode:
+        try:
+            key = original.func.attr.value
+            kword_params = self.METHOD_TO_PARAMS[key]
+        except (AttributeError, KeyError):
+            # Either not a method from the API or too convoluted to be sure.
+            return updated
+
+        # If the existing code is valid, keyword args come after positional args.
+        # Therefore, all positional args must map to the first parameters.
+        args, kwargs = partition(lambda a: not bool(a.keyword), updated.args)
+        if any(k.keyword.value == "request" for k in kwargs):
+            # We've already fixed this file, don't fix it again.
+            return updated
+
+        kwargs, ctrl_kwargs = partition(
+            lambda a: a.keyword.value not in self.CTRL_PARAMS,
+            kwargs
+        )
+
+        args, ctrl_args = args[:len(kword_params)], args[len(kword_params):]
+        ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl))
+                           for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS))
+
+        request_arg = cst.Arg(
+            value=cst.Dict([
+                cst.DictElement(
+                    cst.SimpleString("'{}'".format(name)),
+                    cst.Element(value=arg.value)
+                )
+                # Note: the args + kwargs looks silly, but keep in mind that
+                # the control parameters had to be stripped out, and that
+                # those could have been passed positionally or by keyword.
+                for name, arg in zip(kword_params, args + kwargs)]),
+            keyword=cst.Name("request")
+        )
+
+        return updated.with_changes(
+            args=[request_arg] + ctrl_kwargs
+        )
+
+
+def fix_files(
+    in_dir: pathlib.Path,
+    out_dir: pathlib.Path,
+    *,
+    transformer=generativelanguageCallTransformer(),
+):
+    """Duplicate the input dir to the output dir, fixing file method calls.
+
+    Preconditions:
+    * in_dir is a real directory
+    * out_dir is a real, empty directory
+    """
+    pyfile_gen = (
+        pathlib.Path(os.path.join(root, f))
+        for root, _, files in os.walk(in_dir)
+        for f in files if os.path.splitext(f)[1] == ".py"
+    )
+
+    for fpath in pyfile_gen:
+        with open(fpath, 'r') as f:
+            src = f.read()
+
+        # Parse the code and insert method call fixes.
+        tree = cst.parse_module(src)
+        updated = tree.visit(transformer)
+
+        # Create the path and directory structure for the new file.
+        updated_path = out_dir.joinpath(fpath.relative_to(in_dir))
+        updated_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Generate the updated source file at the corresponding path.
+        with open(updated_path, 'w') as f:
+            f.write(updated.code)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="""Fix up source that uses the generativelanguage client library.
+
+The existing sources are NOT overwritten but are copied to output_dir with changes made.
+
+Note: This tool operates on a best-effort basis when converting positional
+      parameters in client method calls to keyword based parameters.
+      Cases where it WILL FAIL include
+      A) * or ** expansion in a method call.
+      B) Calls via function or method alias (includes free function calls)
+      C) Indirect or dispatched calls (e.g. the method is looked up dynamically)
+
+      These all constitute false negatives. The tool will also produce false
+      positives when an API method shares a name with another method.
+""")
+    parser.add_argument(
+        '-d',
+        '--input-directory',
+        required=True,
+        dest='input_dir',
+        help='the input directory to walk for python files to fix up',
+    )
+    parser.add_argument(
+        '-o',
+        '--output-directory',
+        required=True,
+        dest='output_dir',
+        help='the directory to output files fixed via un-flattening',
+    )
+    args = parser.parse_args()
+    input_dir = pathlib.Path(args.input_dir)
+    output_dir = pathlib.Path(args.output_dir)
+    if not input_dir.is_dir():
+        print(
+            f"input directory '{input_dir}' does not exist or is not a directory",
+            file=sys.stderr,
+        )
+        sys.exit(-1)
+
+    if not output_dir.is_dir():
+        print(
+            f"output directory '{output_dir}' does not exist or is not a directory",
+            file=sys.stderr,
+        )
+        sys.exit(-1)
+
+    if os.listdir(output_dir):
+        print(
+            f"output directory '{output_dir}' is not empty",
+            file=sys.stderr,
+        )
+        sys.exit(-1)
+
+    fix_files(input_dir, output_dir)
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/setup.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/setup.py
new file mode 100644
index 000000000000..0e0b1e55d45f
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/setup.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
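The transformer above only rewrites calls whose trailing attribute name appears in METHOD_TO_PARAMS, maps leading positional arguments onto the flattened field names, and leaves retry/timeout/metadata as ordinary keyword arguments; anything it cannot classify is returned untouched. A minimal sketch of driving it through fix_files, assuming the script is importable and both directories exist (the paths are illustrative):

    import pathlib

    from fixup_generativelanguage_v1beta3_keywords import fix_files

    # Copies every .py file from my_code/ into the empty my_code_fixed/,
    # rewriting e.g.
    #     client.generate_text('models/text-bison-001', prompt, 0.7)
    # into the explicit request form
    #     client.generate_text(request={'model': 'models/text-bison-001',
    #                                   'prompt': prompt, 'temperature': 0.7})
    fix_files(pathlib.Path("my_code"), pathlib.Path("my_code_fixed"))

The setup.py below packages the library itself; note how it derives the PyPI release status from the version declared in gapic_version.py.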
+#
+import io
+import os
+
+import setuptools # type: ignore
+
+package_root = os.path.abspath(os.path.dirname(__file__))
+
+name = 'google-ai-generativelanguage'
+
+
+description = "Google AI Generative Language API client library"
+
+version = {}
+with open(os.path.join(package_root, 'google/ai/generativelanguage/gapic_version.py')) as fp:
+    exec(fp.read(), version)
+version = version["__version__"]
+
+if version[0] == "0":
+    release_status = "Development Status :: 4 - Beta"
+else:
+    release_status = "Development Status :: 5 - Production/Stable"
+
+dependencies = [
+    "google-api-core[grpc] >= 1.34.0, <3.0.0dev,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*,!=2.10.*",
+    "proto-plus >= 1.22.0, <2.0.0dev",
+    "proto-plus >= 1.22.2, <2.0.0dev; python_version>='3.11'",
+    "protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5",
+]
+url = "https://github.com/googleapis/python-ai-generativelanguage"
+
+package_root = os.path.abspath(os.path.dirname(__file__))
+
+readme_filename = os.path.join(package_root, "README.rst")
+with io.open(readme_filename, encoding="utf-8") as readme_file:
+    readme = readme_file.read()
+
+packages = [
+    package
+    for package in setuptools.PEP420PackageFinder.find()
+    if package.startswith("google")
+]
+
+namespaces = ["google", "google.ai"]
+
+setuptools.setup(
+    name=name,
+    version=version,
+    description=description,
+    long_description=readme,
+    author="Google LLC",
+    author_email="googleapis-packages@google.com",
+    license="Apache 2.0",
+    url=url,
+    classifiers=[
+        release_status,
+        "Intended Audience :: Developers",
+        "License :: OSI Approved :: Apache Software License",
+        "Programming Language :: Python",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
+        "Operating System :: OS Independent",
+        "Topic :: Internet",
+    ],
+    platforms="Posix; MacOS X; Windows",
+    packages=packages,
+    python_requires=">=3.7",
+    namespace_packages=namespaces,
+    install_requires=dependencies,
+    include_package_data=True,
+    zip_safe=False,
+)
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.10.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.10.txt
new file mode 100644
index 000000000000..ed7f9aed2559
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.10.txt
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+# This constraints file is required for unit tests.
+# List all library dependencies and extras in this file.
+google-api-core
+proto-plus
+protobuf
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.11.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.11.txt
new file mode 100644
index 000000000000..ed7f9aed2559
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.11.txt
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+# This constraints file is required for unit tests.
+# List all library dependencies and extras in this file.
+google-api-core
+proto-plus
+protobuf
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.12.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.12.txt
new file mode 100644
index 000000000000..ed7f9aed2559
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.12.txt
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+# This constraints file is required for unit tests.
+# List all library dependencies and extras in this file.
+google-api-core
+proto-plus
+protobuf
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.7.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.7.txt
new file mode 100644
index 000000000000..6c44adfea7ee
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.7.txt
@@ -0,0 +1,9 @@
+# This constraints file is used to check that lower bounds
+# are correct in setup.py
+# List all library dependencies and extras in this file.
+# Pin the version to the lower bound.
+# e.g., if setup.py has "google-cloud-foo >= 1.14.0, < 2.0.0dev",
+# Then this file should have google-cloud-foo==1.14.0
+google-api-core==1.34.0
+proto-plus==1.22.0
+protobuf==3.19.5
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.8.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.8.txt
new file mode 100644
index 000000000000..ed7f9aed2559
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.8.txt
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+# This constraints file is required for unit tests.
+# List all library dependencies and extras in this file.
+google-api-core
+proto-plus
+protobuf
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.9.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.9.txt
new file mode 100644
index 000000000000..ed7f9aed2559
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.9.txt
@@ -0,0 +1,6 @@
+# -*- coding: utf-8 -*-
+# This constraints file is required for unit tests.
+# List all library dependencies and extras in this file.
+google-api-core
+proto-plus
+protobuf
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/__init__.py
new file mode 100644
index 000000000000..1b4db446eb8d
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/__init__.py
@@ -0,0 +1,16 @@
+
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/__init__.py
new file mode 100644
index 000000000000..1b4db446eb8d
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/__init__.py
@@ -0,0 +1,16 @@
+
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/__init__.py
new file mode 100644
index 000000000000..1b4db446eb8d
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/__init__.py
@@ -0,0 +1,16 @@
+
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/__init__.py
new file mode 100644
index 000000000000..1b4db446eb8d
--- /dev/null
+++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/__init__.py
@@ -0,0 +1,16 @@
+
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
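One point worth calling out about the testing/constraints-*.txt files above: constraints-3.7.txt pins each dependency to the exact lower bound declared in setup.py, so CI on the oldest supported Python actually exercises the oldest allowed versions, while the 3.8 through 3.12 files leave versions unpinned and float to current releases. They are plain pip constraints files; a sketch of how a test environment might consume them (the editable-install invocation is illustrative):

    pip install -e . -c testing/constraints-3.7.txt   # oldest allowed dependency set
    pip install -e . -c testing/constraints-3.10.txt  # floating, current releases

The unit tests that follow run entirely against mocked transports, so either dependency set can be exercised without network access.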
+# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_discuss_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_discuss_service.py new file mode 100644 index 000000000000..fff2f9a81134 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_discuss_service.py @@ -0,0 +1,2206 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +import grpc +from grpc.experimental import aio +from collections.abc import Iterable +from google.protobuf import json_format +import json +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule +from proto.marshal.rules import wrappers +from requests import Response +from requests import Request, PreparedRequest +from requests.sessions import Session +from google.protobuf import json_format + +from google.ai.generativelanguage_v1beta3.services.discuss_service import DiscussServiceAsyncClient +from google.ai.generativelanguage_v1beta3.services.discuss_service import DiscussServiceClient +from google.ai.generativelanguage_v1beta3.services.discuss_service import transports +from google.ai.generativelanguage_v1beta3.types import citation +from google.ai.generativelanguage_v1beta3.types import discuss_service +from google.ai.generativelanguage_v1beta3.types import safety +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import path_template +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account +import google.auth + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
+def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert DiscussServiceClient._get_default_mtls_endpoint(None) is None + assert DiscussServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert DiscussServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert DiscussServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert DiscussServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert DiscussServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class,transport_name", [ + (DiscussServiceClient, "grpc"), + (DiscussServiceAsyncClient, "grpc_asyncio"), + (DiscussServiceClient, "rest"), +]) +def test_discuss_service_client_from_service_account_info(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + 'generativelanguage.googleapis.com:443' + if transport_name in ['grpc', 'grpc_asyncio'] + else + 'https://generativelanguage.googleapis.com' + ) + + +@pytest.mark.parametrize("transport_class,transport_name", [ + (transports.DiscussServiceGrpcTransport, "grpc"), + (transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.DiscussServiceRestTransport, "rest"), +]) +def test_discuss_service_client_service_account_always_use_jwt(transport_class, transport_name): + with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize("client_class,transport_name", [ + (DiscussServiceClient, "grpc"), + (DiscussServiceAsyncClient, "grpc_asyncio"), + (DiscussServiceClient, "rest"), +]) +def test_discuss_service_client_from_service_account_file(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert 
client.transport._host == ( + 'generativelanguage.googleapis.com:443' + if transport_name in ['grpc', 'grpc_asyncio'] + else + 'https://generativelanguage.googleapis.com' + ) + + +def test_discuss_service_client_get_transport_class(): + transport = DiscussServiceClient.get_transport_class() + available_transports = [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceRestTransport, + ] + assert transport in available_transports + + transport = DiscussServiceClient.get_transport_class("grpc") + assert transport == transports.DiscussServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc"), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest"), +]) +@mock.patch.object(DiscussServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceClient)) +@mock.patch.object(DiscussServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceAsyncClient)) +def test_discuss_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(DiscussServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(DiscussServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class(transport=transport_name) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class(transport=transport_name) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions(api_audience="https://language.googleapis.com") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com" + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", "true"), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", "false"), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", "true"), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", "false"), +]) +@mock.patch.object(DiscussServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceClient)) +@mock.patch.object(DiscussServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_discuss_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. 
+ + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class", [ + DiscussServiceClient, DiscussServiceAsyncClient +]) +@mock.patch.object(DiscussServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceClient)) +@mock.patch.object(DiscussServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceAsyncClient)) +def test_discuss_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc"), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest"), +]) +def test_discuss_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. 
+ options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", grpc_helpers), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", None), +]) +def test_discuss_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +def test_discuss_service_client_client_options_from_dict(): + with mock.patch('google.ai.generativelanguage_v1beta3.services.discuss_service.transports.DiscussServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = DiscussServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", grpc_helpers), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), +]) +def test_discuss_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. 
+    with mock.patch.object(
+        google.auth, "load_credentials_from_file", autospec=True
+    ) as load_creds, mock.patch.object(
+        google.auth, "default", autospec=True
+    ) as adc, mock.patch.object(
+        grpc_helpers, "create_channel"
+    ) as create_channel:
+        creds = ga_credentials.AnonymousCredentials()
+        file_creds = ga_credentials.AnonymousCredentials()
+        load_creds.return_value = (file_creds, None)
+        adc.return_value = (creds, None)
+        client = client_class(client_options=options, transport=transport_name)
+        create_channel.assert_called_with(
+            "generativelanguage.googleapis.com:443",
+            credentials=file_creds,
+            credentials_file=None,
+            quota_project_id=None,
+            default_scopes=(),
+            scopes=None,
+            default_host="generativelanguage.googleapis.com",
+            ssl_credentials=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+
+@pytest.mark.parametrize("request_type", [
+    discuss_service.GenerateMessageRequest,
+    dict,
+])
+def test_generate_message(request_type, transport: str = 'grpc'):
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.generate_message),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = discuss_service.GenerateMessageResponse(
+        )
+        response = client.generate_message(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == discuss_service.GenerateMessageRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, discuss_service.GenerateMessageResponse)
+
+
+def test_generate_message_empty_call():
+    # This test is a coverage failsafe to make sure that totally empty calls,
+    # i.e. request == None and no flattened fields passed, work.
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.generate_message),
+            '__call__') as call:
+        client.generate_message()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == discuss_service.GenerateMessageRequest()
+
+@pytest.mark.asyncio
+async def test_generate_message_async(transport: str = 'grpc_asyncio', request_type=discuss_service.GenerateMessageRequest):
+    client = DiscussServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.generate_message),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.GenerateMessageResponse(
+        ))
+        response = await client.generate_message(request)
+
+        # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == discuss_service.GenerateMessageRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.GenerateMessageResponse) + + +@pytest.mark.asyncio +async def test_generate_message_async_from_dict(): + await test_generate_message_async(request_type=dict) + + +def test_generate_message_field_headers(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = discuss_service.GenerateMessageRequest() + + request.model = 'model_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_message), + '__call__') as call: + call.return_value = discuss_service.GenerateMessageResponse() + client.generate_message(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model=model_value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_generate_message_field_headers_async(): + client = DiscussServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = discuss_service.GenerateMessageRequest() + + request.model = 'model_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_message), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.GenerateMessageResponse()) + await client.generate_message(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model=model_value', + ) in kw['metadata'] + + +def test_generate_message_flattened(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_message), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = discuss_service.GenerateMessageResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.generate_message( + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + + # Establish that the underlying call was made with the expected + # request object values. 
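+        # Float fields (temperature, top_p) are compared with math.isclose
+        # rather than ==, since proto float fields may not round-trip exactly.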
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].model
+        mock_val = 'model_value'
+        assert arg == mock_val
+        arg = args[0].prompt
+        mock_val = discuss_service.MessagePrompt(context='context_value')
+        assert arg == mock_val
+        assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6)
+        arg = args[0].candidate_count
+        mock_val = 1573
+        assert arg == mock_val
+        assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6)
+        arg = args[0].top_k
+        mock_val = 541
+        assert arg == mock_val
+
+
+def test_generate_message_flattened_error():
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.generate_message(
+            discuss_service.GenerateMessageRequest(),
+            model='model_value',
+            prompt=discuss_service.MessagePrompt(context='context_value'),
+            temperature=0.1198,
+            candidate_count=1573,
+            top_p=0.546,
+            top_k=541,
+        )
+
+@pytest.mark.asyncio
+async def test_generate_message_flattened_async():
+    client = DiscussServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.generate_message),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.GenerateMessageResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.generate_message(
+            model='model_value',
+            prompt=discuss_service.MessagePrompt(context='context_value'),
+            temperature=0.1198,
+            candidate_count=1573,
+            top_p=0.546,
+            top_k=541,
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].model
+        mock_val = 'model_value'
+        assert arg == mock_val
+        arg = args[0].prompt
+        mock_val = discuss_service.MessagePrompt(context='context_value')
+        assert arg == mock_val
+        assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6)
+        arg = args[0].candidate_count
+        mock_val = 1573
+        assert arg == mock_val
+        assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6)
+        arg = args[0].top_k
+        mock_val = 541
+        assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_generate_message_flattened_error_async():
+    client = DiscussServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.generate_message(
+            discuss_service.GenerateMessageRequest(),
+            model='model_value',
+            prompt=discuss_service.MessagePrompt(context='context_value'),
+            temperature=0.1198,
+            candidate_count=1573,
+            top_p=0.546,
+            top_k=541,
+        )
+
+
+@pytest.mark.parametrize("request_type", [
+    discuss_service.CountMessageTokensRequest,
+    dict,
+])
+def test_count_message_tokens(request_type, transport: str = 'grpc'):
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.count_message_tokens),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = discuss_service.CountMessageTokensResponse(
+            token_count=1193,
+        )
+        response = client.count_message_tokens(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == discuss_service.CountMessageTokensRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, discuss_service.CountMessageTokensResponse)
+    assert response.token_count == 1193
+
+
+def test_count_message_tokens_empty_call():
+    # This test is a coverage failsafe to make sure that totally empty calls,
+    # i.e. request == None and no flattened fields passed, work.
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.count_message_tokens),
+            '__call__') as call:
+        client.count_message_tokens()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == discuss_service.CountMessageTokensRequest()
+
+@pytest.mark.asyncio
+async def test_count_message_tokens_async(transport: str = 'grpc_asyncio', request_type=discuss_service.CountMessageTokensRequest):
+    client = DiscussServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.count_message_tokens),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.CountMessageTokensResponse(
+            token_count=1193,
+        ))
+        response = await client.count_message_tokens(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == discuss_service.CountMessageTokensRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, discuss_service.CountMessageTokensResponse)
+    assert response.token_count == 1193
+
+
+@pytest.mark.asyncio
+async def test_count_message_tokens_async_from_dict():
+    await test_count_message_tokens_async(request_type=dict)
+
+
+def test_count_message_tokens_field_headers():
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = discuss_service.CountMessageTokensRequest()
+
+    request.model = 'model_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.count_message_tokens),
+            '__call__') as call:
+        call.return_value = discuss_service.CountMessageTokensResponse()
+        client.count_message_tokens(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+        # Establish that the field header was sent.
+        _, _, kw = call.mock_calls[0]
+        assert (
+            'x-goog-request-params',
+            'model=model_value',
+        ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_count_message_tokens_field_headers_async():
+    client = DiscussServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = discuss_service.CountMessageTokensRequest()
+
+    request.model = 'model_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.count_message_tokens),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.CountMessageTokensResponse())
+        await client.count_message_tokens(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+        # Establish that the field header was sent.
+        _, _, kw = call.mock_calls[0]
+        assert (
+            'x-goog-request-params',
+            'model=model_value',
+        ) in kw['metadata']
+
+
+def test_count_message_tokens_flattened():
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.count_message_tokens),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = discuss_service.CountMessageTokensResponse()
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.count_message_tokens(
+            model='model_value',
+            prompt=discuss_service.MessagePrompt(context='context_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].model
+        mock_val = 'model_value'
+        assert arg == mock_val
+        arg = args[0].prompt
+        mock_val = discuss_service.MessagePrompt(context='context_value')
+        assert arg == mock_val
+
+
+def test_count_message_tokens_flattened_error():
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.count_message_tokens(
+            discuss_service.CountMessageTokensRequest(),
+            model='model_value',
+            prompt=discuss_service.MessagePrompt(context='context_value'),
+        )
+
+@pytest.mark.asyncio
+async def test_count_message_tokens_flattened_async():
+    client = DiscussServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.count_message_tokens),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.CountMessageTokensResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+ response = await client.count_message_tokens( + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = 'model_value' + assert arg == mock_val + arg = args[0].prompt + mock_val = discuss_service.MessagePrompt(context='context_value') + assert arg == mock_val + +@pytest.mark.asyncio +async def test_count_message_tokens_flattened_error_async(): + client = DiscussServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.count_message_tokens( + discuss_service.CountMessageTokensRequest(), + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + ) + + +@pytest.mark.parametrize("request_type", [ + discuss_service.GenerateMessageRequest, + dict, +]) +def test_generate_message_rest(request_type): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'model': 'models/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = discuss_service.GenerateMessageResponse( + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = discuss_service.GenerateMessageResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.generate_message(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.GenerateMessageResponse) + + +def test_generate_message_rest_required_fields(request_type=discuss_service.GenerateMessageRequest): + transport_class = transports.DiscussServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_message._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = 'model_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_message._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == 'model_value' + + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. 
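+    # Note: no network traffic occurs in this test; it exercises the generated
+    # _get_unset_required_fields helper above and then fakes the HTTP layer
+    # below to confirm the required "model" field survives the JSON round-trip.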
+    return_value = discuss_service.GenerateMessageResponse()
+    # Mock the http request call within the method and fake a response.
+    with mock.patch.object(Session, 'request') as req:
+        # We need to mock transcode() because providing default values
+        # for required fields will fail the real version if the http_options
+        # expect actual values for those fields.
+        with mock.patch.object(path_template, 'transcode') as transcode:
+            # A uri without fields and an empty body will force all the
+            # request fields to show up in the query_params.
+            pb_request = request_type.pb(request)
+            transcode_result = {
+                'uri': 'v1/sample_method',
+                'method': "post",
+                'query_params': pb_request,
+            }
+            transcode_result['body'] = pb_request
+            transcode.return_value = transcode_result
+
+            response_value = Response()
+            response_value.status_code = 200
+
+            pb_return_value = discuss_service.GenerateMessageResponse.pb(return_value)
+            json_return_value = json_format.MessageToJson(pb_return_value)
+
+            response_value._content = json_return_value.encode('UTF-8')
+            req.return_value = response_value
+
+            response = client.generate_message(request)
+
+            expected_params = [
+                ('$alt', 'json;enum-encoding=int')
+            ]
+            actual_params = req.call_args.kwargs['params']
+            assert expected_params == actual_params
+
+
+def test_generate_message_rest_unset_required_fields():
+    transport = transports.DiscussServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+    unset_fields = transport.generate_message._get_unset_required_fields({})
+    assert set(unset_fields) == (set(()) & set(("model", "prompt", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_generate_message_rest_interceptors(null_interceptor):
+    transport = transports.DiscussServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.DiscussServiceRestInterceptor(),
+        )
+    client = DiscussServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+        mock.patch.object(path_template, "transcode") as transcode, \
+        mock.patch.object(transports.DiscussServiceRestInterceptor, "post_generate_message") as post, \
+        mock.patch.object(transports.DiscussServiceRestInterceptor, "pre_generate_message") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = discuss_service.GenerateMessageRequest.pb(discuss_service.GenerateMessageRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = discuss_service.GenerateMessageResponse.to_json(discuss_service.GenerateMessageResponse())
+
+        request = discuss_service.GenerateMessageRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = discuss_service.GenerateMessageResponse()
+
+        client.generate_message(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_generate_message_rest_bad_request(transport: str = 'rest', request_type=discuss_service.GenerateMessageRequest):
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'model': 'models/sample1'}
+    request = 
request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.generate_message(request) + + +def test_generate_message_rest_flattened(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = discuss_service.GenerateMessageResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {'model': 'models/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = discuss_service.GenerateMessageResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.generate_message(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{model=models/*}:generateMessage" % client.transport._host, args[1]) + + +def test_generate_message_rest_flattened_error(transport: str = 'rest'): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.generate_message( + discuss_service.GenerateMessageRequest(), + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + + +def test_generate_message_rest_error(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + discuss_service.CountMessageTokensRequest, + dict, +]) +def test_count_message_tokens_rest(request_type): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'model': 'models/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. 
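+        # No real HTTP traffic occurs here: the transport's requests.Session is
+        # patched, and the fake Response carries the JSON encoding of the proto.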
+ return_value = discuss_service.CountMessageTokensResponse( + token_count=1193, + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = discuss_service.CountMessageTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.count_message_tokens(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.CountMessageTokensResponse) + assert response.token_count == 1193 + + +def test_count_message_tokens_rest_required_fields(request_type=discuss_service.CountMessageTokensRequest): + transport_class = transports.DiscussServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).count_message_tokens._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = 'model_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).count_message_tokens._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == 'model_value' + + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = discuss_service.CountMessageTokensResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+            pb_request = request_type.pb(request)
+            transcode_result = {
+                'uri': 'v1/sample_method',
+                'method': "post",
+                'query_params': pb_request,
+            }
+            transcode_result['body'] = pb_request
+            transcode.return_value = transcode_result
+
+            response_value = Response()
+            response_value.status_code = 200
+
+            pb_return_value = discuss_service.CountMessageTokensResponse.pb(return_value)
+            json_return_value = json_format.MessageToJson(pb_return_value)
+
+            response_value._content = json_return_value.encode('UTF-8')
+            req.return_value = response_value
+
+            response = client.count_message_tokens(request)
+
+            expected_params = [
+                ('$alt', 'json;enum-encoding=int')
+            ]
+            actual_params = req.call_args.kwargs['params']
+            assert expected_params == actual_params
+
+
+def test_count_message_tokens_rest_unset_required_fields():
+    transport = transports.DiscussServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+    unset_fields = transport.count_message_tokens._get_unset_required_fields({})
+    assert set(unset_fields) == (set(()) & set(("model", "prompt", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_count_message_tokens_rest_interceptors(null_interceptor):
+    transport = transports.DiscussServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.DiscussServiceRestInterceptor(),
+        )
+    client = DiscussServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+        mock.patch.object(path_template, "transcode") as transcode, \
+        mock.patch.object(transports.DiscussServiceRestInterceptor, "post_count_message_tokens") as post, \
+        mock.patch.object(transports.DiscussServiceRestInterceptor, "pre_count_message_tokens") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = discuss_service.CountMessageTokensRequest.pb(discuss_service.CountMessageTokensRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = discuss_service.CountMessageTokensResponse.to_json(discuss_service.CountMessageTokensResponse())
+
+        request = discuss_service.CountMessageTokensRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = discuss_service.CountMessageTokensResponse()
+
+        client.count_message_tokens(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_count_message_tokens_rest_bad_request(transport: str = 'rest', request_type=discuss_service.CountMessageTokensRequest):
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'model': 'models/sample1'}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
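+    # A 400 status on the mocked session should surface to the caller as
+    # core_exceptions.BadRequest.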
+ with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.count_message_tokens(request) + + +def test_count_message_tokens_rest_flattened(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = discuss_service.CountMessageTokensResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {'model': 'models/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = discuss_service.CountMessageTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.count_message_tokens(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{model=models/*}:countMessageTokens" % client.transport._host, args[1]) + + +def test_count_message_tokens_rest_flattened_error(transport: str = 'rest'): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.count_message_tokens( + discuss_service.CountMessageTokensRequest(), + model='model_value', + prompt=discuss_service.MessagePrompt(context='context_value'), + ) + + +def test_count_message_tokens_rest_error(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DiscussServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = DiscussServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. 
+ options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = DiscussServiceClient( + client_options=options, + credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DiscussServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = DiscussServiceClient(transport=transport) + assert client.transport is transport + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.DiscussServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + +@pytest.mark.parametrize("transport_class", [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + transports.DiscussServiceRestTransport, +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, 'default') as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + +@pytest.mark.parametrize("transport_name", [ + "grpc", + "rest", +]) +def test_transport_kind(transport_name): + transport = DiscussServiceClient.get_transport_class(transport_name)( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert transport.kind == transport_name + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.DiscussServiceGrpcTransport, + ) + +def test_discuss_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.DiscussServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_discuss_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.ai.generativelanguage_v1beta3.services.discuss_service.transports.DiscussServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.DiscussServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
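+    # (The abstract base defines the RPC surface only; the concrete gRPC and
+    # REST transports are the ones expected to implement these methods.)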
+    methods = (
+        'generate_message',
+        'count_message_tokens',
+    )
+    for method in methods:
+        with pytest.raises(NotImplementedError):
+            getattr(transport, method)(request=object())
+
+    with pytest.raises(NotImplementedError):
+        transport.close()
+
+    # Catch all for all remaining methods and properties
+    remainder = [
+        'kind',
+    ]
+    for r in remainder:
+        with pytest.raises(NotImplementedError):
+            getattr(transport, r)()
+
+
+def test_discuss_service_base_transport_with_credentials_file():
+    # Instantiate the base transport with a credentials file
+    with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta3.services.discuss_service.transports.DiscussServiceTransport._prep_wrapped_messages') as Transport:
+        Transport.return_value = None
+        load_creds.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport = transports.DiscussServiceTransport(
+            credentials_file="credentials.json",
+            quota_project_id="octopus",
+        )
+        load_creds.assert_called_once_with("credentials.json",
+            scopes=None,
+            default_scopes=(),
+            quota_project_id="octopus",
+        )
+
+
+def test_discuss_service_base_transport_with_adc():
+    # Test the default credentials are used if credentials and credentials_file are None.
+    with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta3.services.discuss_service.transports.DiscussServiceTransport._prep_wrapped_messages') as Transport:
+        Transport.return_value = None
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport = transports.DiscussServiceTransport()
+        adc.assert_called_once()
+
+
+def test_discuss_service_auth_adc():
+    # If no credentials are provided, we should use ADC credentials.
+    with mock.patch.object(google.auth, 'default', autospec=True) as adc:
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        DiscussServiceClient()
+        adc.assert_called_once_with(
+            scopes=None,
+            default_scopes=(),
+            quota_project_id=None,
+        )
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.DiscussServiceGrpcTransport,
+        transports.DiscussServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_discuss_service_transport_auth_adc(transport_class):
+    # If credentials and host are not provided, the transport class should use
+    # ADC credentials.
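+    # google.auth.default() (Application Default Credentials) is patched so the
+    # test can also verify that caller-supplied scopes and quota project are
+    # forwarded to the ADC lookup.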
+    with mock.patch.object(google.auth, 'default', autospec=True) as adc:
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport_class(quota_project_id="octopus", scopes=["1", "2"])
+        adc.assert_called_once_with(
+            scopes=["1", "2"],
+            default_scopes=(),
+            quota_project_id="octopus",
+        )
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.DiscussServiceGrpcTransport,
+        transports.DiscussServiceGrpcAsyncIOTransport,
+        transports.DiscussServiceRestTransport,
+    ],
+)
+def test_discuss_service_transport_auth_gdch_credentials(transport_class):
+    host = 'https://language.com'
+    api_audience_tests = [None, 'https://language2.com']
+    api_audience_expect = [host, 'https://language2.com']
+    for t, e in zip(api_audience_tests, api_audience_expect):
+        with mock.patch.object(google.auth, 'default', autospec=True) as adc:
+            gdch_mock = mock.MagicMock()
+            type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock)
+            adc.return_value = (gdch_mock, None)
+            transport_class(host=host, api_audience=t)
+            gdch_mock.with_gdch_audience.assert_called_once_with(
+                e
+            )
+
+
+@pytest.mark.parametrize(
+    "transport_class,grpc_helpers",
+    [
+        (transports.DiscussServiceGrpcTransport, grpc_helpers),
+        (transports.DiscussServiceGrpcAsyncIOTransport, grpc_helpers_async)
+    ],
+)
+def test_discuss_service_transport_create_channel(transport_class, grpc_helpers):
+    # If credentials and host are not provided, the transport class should use
+    # ADC credentials.
+    with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object(
+        grpc_helpers, "create_channel", autospec=True
+    ) as create_channel:
+        creds = ga_credentials.AnonymousCredentials()
+        adc.return_value = (creds, None)
+        transport_class(
+            quota_project_id="octopus",
+            scopes=["1", "2"]
+        )
+
+        create_channel.assert_called_with(
+            "generativelanguage.googleapis.com:443",
+            credentials=creds,
+            credentials_file=None,
+            quota_project_id="octopus",
+            default_scopes=(),
+            scopes=["1", "2"],
+            default_host="generativelanguage.googleapis.com",
+            ssl_credentials=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+
+@pytest.mark.parametrize("transport_class", [transports.DiscussServiceGrpcTransport, transports.DiscussServiceGrpcAsyncIOTransport])
+def test_discuss_service_grpc_transport_client_cert_source_for_mtls(
+    transport_class
+):
+    cred = ga_credentials.AnonymousCredentials()
+
+    # Check ssl_channel_credentials is used if provided.
+    with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+        mock_ssl_channel_creds = mock.Mock()
+        transport_class(
+            host="squid.clam.whelk",
+            credentials=cred,
+            ssl_channel_credentials=mock_ssl_channel_creds
+        )
+        mock_create_channel.assert_called_once_with(
+            "squid.clam.whelk:443",
+            credentials=cred,
+            credentials_file=None,
+            scopes=None,
+            ssl_credentials=mock_ssl_channel_creds,
+            quota_project_id=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+    # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+    # is used.
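+    # client_cert_source_callback (defined earlier in this module) returns the
+    # cert/key byte pair the transport should wrap via grpc.ssl_channel_credentials.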
+    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+            transport_class(
+                credentials=cred,
+                client_cert_source_for_mtls=client_cert_source_callback
+            )
+            expected_cert, expected_key = client_cert_source_callback()
+            mock_ssl_cred.assert_called_once_with(
+                certificate_chain=expected_cert,
+                private_key=expected_key
+            )
+
+def test_discuss_service_http_transport_client_cert_source_for_mtls():
+    cred = ga_credentials.AnonymousCredentials()
+    with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel:
+        transports.DiscussServiceRestTransport(
+            credentials=cred,
+            client_cert_source_for_mtls=client_cert_source_callback
+        )
+        mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)
+
+
+@pytest.mark.parametrize("transport_name", [
+    "grpc",
+    "grpc_asyncio",
+    "rest",
+])
+def test_discuss_service_host_no_port(transport_name):
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'),
+        transport=transport_name,
+    )
+    assert client.transport._host == (
+        'generativelanguage.googleapis.com:443'
+        if transport_name in ['grpc', 'grpc_asyncio']
+        else 'https://generativelanguage.googleapis.com'
+    )
+
+@pytest.mark.parametrize("transport_name", [
+    "grpc",
+    "grpc_asyncio",
+    "rest",
+])
+def test_discuss_service_host_with_port(transport_name):
+    client = DiscussServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'),
+        transport=transport_name,
+    )
+    assert client.transport._host == (
+        'generativelanguage.googleapis.com:8000'
+        if transport_name in ['grpc', 'grpc_asyncio']
+        else 'https://generativelanguage.googleapis.com:8000'
+    )
+
+@pytest.mark.parametrize("transport_name", [
+    "rest",
+])
+def test_discuss_service_client_transport_session_collision(transport_name):
+    creds1 = ga_credentials.AnonymousCredentials()
+    creds2 = ga_credentials.AnonymousCredentials()
+    client1 = DiscussServiceClient(
+        credentials=creds1,
+        transport=transport_name,
+    )
+    client2 = DiscussServiceClient(
+        credentials=creds2,
+        transport=transport_name,
+    )
+    session1 = client1.transport.generate_message._session
+    session2 = client2.transport.generate_message._session
+    assert session1 != session2
+    session1 = client1.transport.count_message_tokens._session
+    session2 = client2.transport.count_message_tokens._session
+    assert session1 != session2
+
+
+def test_discuss_service_grpc_transport_channel():
+    channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.DiscussServiceGrpcTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_discuss_service_grpc_asyncio_transport_channel():
+    channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.DiscussServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize("transport_class", [transports.DiscussServiceGrpcTransport, transports.DiscussServiceGrpcAsyncIOTransport])
+def test_discuss_service_transport_channel_mtls_with_client_cert_source(
+    transport_class
+):
+    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
+        with mock.patch.object(transport_class, "create_channel") as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = ga_credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(google.auth, 'default') as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=None,
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize("transport_class", [transports.DiscussServiceGrpcTransport, transports.DiscussServiceGrpcAsyncIOTransport]) +def test_discuss_service_transport_channel_mtls_with_adc( + transport_class +): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_model_path(): + model = "squid" + expected = "models/{model}".format(model=model, ) + actual = DiscussServiceClient.model_path(model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "model": "clam", + } + path = DiscussServiceClient.model_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_model_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "whelk" + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = DiscussServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "octopus", + } + path = DiscussServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "oyster" + expected = "folders/{folder}".format(folder=folder, ) + actual = DiscussServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nudibranch", + } + path = DiscussServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "cuttlefish" + expected = "organizations/{organization}".format(organization=organization, ) + actual = DiscussServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "mussel", + } + path = DiscussServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. 
+ actual = DiscussServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "winkle" + expected = "projects/{project}".format(project=project, ) + actual = DiscussServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "nautilus", + } + path = DiscussServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "scallop" + location = "abalone" + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = DiscussServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "squid", + "location": "clam", + } + path = DiscussServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.DiscussServiceTransport, '_prep_wrapped_messages') as prep: + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.DiscussServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = DiscussServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = DiscussServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "rest": "_session", + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport + ) + with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close: + with client: + close.assert_not_called() + close.assert_called_once() + +def test_client_ctx(): + transports = [ + 'rest', + 'grpc', + ] + for transport in transports: + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport + ) + # Test client calls underlying transport. 
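+        # Entering and exiting the client as a context manager should close the
+        # underlying transport exactly once.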
+ with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + +@pytest.mark.parametrize("client_class,transport_class", [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport), +]) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_model_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_model_service.py new file mode 100644 index 000000000000..b0c5932de677 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_model_service.py @@ -0,0 +1,4869 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+import os
+# try/except added for compatibility with python < 3.8
+try:
+ from unittest import mock
+ from unittest.mock import AsyncMock # pragma: NO COVER
+except ImportError: # pragma: NO COVER
+ import mock
+
+import grpc
+from grpc.experimental import aio
+from collections.abc import Iterable
+from google.protobuf import json_format
+import json
+import math
+import pytest
+from proto.marshal.rules.dates import DurationRule, TimestampRule
+from proto.marshal.rules import wrappers
+from requests import Response
+from requests import Request, PreparedRequest
+from requests.sessions import Session
+
+from google.ai.generativelanguage_v1beta3.services.model_service import ModelServiceAsyncClient
+from google.ai.generativelanguage_v1beta3.services.model_service import ModelServiceClient
+from google.ai.generativelanguage_v1beta3.services.model_service import pagers
+from google.ai.generativelanguage_v1beta3.services.model_service import transports
+from google.ai.generativelanguage_v1beta3.types import model
+from google.ai.generativelanguage_v1beta3.types import model_service
+from google.ai.generativelanguage_v1beta3.types import tuned_model
+from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model
+from google.api_core import client_options
+from google.api_core import exceptions as core_exceptions
+from google.api_core import future
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers
+from google.api_core import grpc_helpers_async
+from google.api_core import operation
+from google.api_core import operation_async # type: ignore
+from google.api_core import operations_v1
+from google.api_core import path_template
+from google.auth import credentials as ga_credentials
+from google.auth.exceptions import MutualTLSChannelError
+from google.longrunning import operations_pb2 # type: ignore
+from google.oauth2 import service_account
+from google.protobuf import field_mask_pb2 # type: ignore
+from google.protobuf import timestamp_pb2 # type: ignore
+import google.auth
+
+
+def client_cert_source_callback():
+ return b"cert bytes", b"key bytes"
+
+
+# If default endpoint is localhost, then default mtls endpoint will be the same.
+# This method modifies the default endpoint so the client can produce a different
+# mtls endpoint for endpoint testing purposes.
+def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert ModelServiceClient._get_default_mtls_endpoint(None) is None + assert ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class,transport_name", [ + (ModelServiceClient, "grpc"), + (ModelServiceAsyncClient, "grpc_asyncio"), + (ModelServiceClient, "rest"), +]) +def test_model_service_client_from_service_account_info(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + 'generativelanguage.googleapis.com:443' + if transport_name in ['grpc', 'grpc_asyncio'] + else + 'https://generativelanguage.googleapis.com' + ) + + +@pytest.mark.parametrize("transport_class,transport_name", [ + (transports.ModelServiceGrpcTransport, "grpc"), + (transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.ModelServiceRestTransport, "rest"), +]) +def test_model_service_client_service_account_always_use_jwt(transport_class, transport_name): + with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize("client_class,transport_name", [ + (ModelServiceClient, "grpc"), + (ModelServiceAsyncClient, "grpc_asyncio"), + (ModelServiceClient, "rest"), +]) +def test_model_service_client_from_service_account_file(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + 
'generativelanguage.googleapis.com:443' + if transport_name in ['grpc', 'grpc_asyncio'] + else + 'https://generativelanguage.googleapis.com' + ) + + +def test_model_service_client_get_transport_class(): + transport = ModelServiceClient.get_transport_class() + available_transports = [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceRestTransport, + ] + assert transport in available_transports + + transport = ModelServiceClient.get_transport_class("grpc") + assert transport == transports.ModelServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest"), +]) +@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) +@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) +def test_model_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". 
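+ # With "always", the client should target DEFAULT_MTLS_ENDPOINT even though no client certificate is configured.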
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
+ with mock.patch.object(transport_class, '__init__') as patched:
+ patched.return_value = None
+ client = client_class(transport=transport_name)
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_MTLS_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ api_audience=None,
+ )
+
+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
+ # unsupported value.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+ with pytest.raises(MutualTLSChannelError):
+ client = client_class(transport=transport_name)
+
+ # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}):
+ with pytest.raises(ValueError):
+ client = client_class(transport=transport_name)
+
+ # Check the case quota_project_id is provided
+ options = client_options.ClientOptions(quota_project_id="octopus")
+ with mock.patch.object(transport_class, '__init__') as patched:
+ patched.return_value = None
+ client = client_class(client_options=options, transport=transport_name)
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id="octopus",
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ api_audience=None,
+ )
+ # Check the case api_audience is provided
+ options = client_options.ClientOptions(api_audience="https://language.googleapis.com")
+ with mock.patch.object(transport_class, '__init__') as patched:
+ patched.return_value = None
+ client = client_class(client_options=options, transport=transport_name)
+ patched.assert_called_once_with(
+ credentials=None,
+ credentials_file=None,
+ host=client.DEFAULT_ENDPOINT,
+ scopes=None,
+ client_cert_source_for_mtls=None,
+ quota_project_id=None,
+ client_info=transports.base.DEFAULT_CLIENT_INFO,
+ always_use_jwt_access=True,
+ api_audience="https://language.googleapis.com"
+ )
+
+@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [
+ (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"),
+ (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"),
+ (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"),
+ (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"),
+ (ModelServiceClient, transports.ModelServiceRestTransport, "rest", "true"),
+ (ModelServiceClient, transports.ModelServiceRestTransport, "rest", "false"),
+])
+@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient))
+@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient))
+@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
+def test_model_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env):
+ # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default
+ # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.
+
+ # Check the case client_cert_source is provided.
Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class", [ + ModelServiceClient, ModelServiceAsyncClient +]) +@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) +@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) +def test_model_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
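+ # An api_endpoint passed via client options should take precedence over the environment-derived endpoint.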
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest"), +]) +def test_model_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. 
+ options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", grpc_helpers), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest", None), +]) +def test_model_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +def test_model_service_client_client_options_from_dict(): + with mock.patch('google.ai.generativelanguage_v1beta3.services.model_service.transports.ModelServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = ModelServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", grpc_helpers), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), +]) +def test_model_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. 
+ with mock.patch.object(
+ google.auth, "load_credentials_from_file", autospec=True
+ ) as load_creds, mock.patch.object(
+ google.auth, "default", autospec=True
+ ) as adc, mock.patch.object(
+ grpc_helpers, "create_channel"
+ ) as create_channel:
+ creds = ga_credentials.AnonymousCredentials()
+ file_creds = ga_credentials.AnonymousCredentials()
+ load_creds.return_value = (file_creds, None)
+ adc.return_value = (creds, None)
+ client = client_class(client_options=options, transport=transport_name)
+ create_channel.assert_called_with(
+ "generativelanguage.googleapis.com:443",
+ credentials=file_creds,
+ credentials_file=None,
+ quota_project_id=None,
+ default_scopes=(),
+ scopes=None,
+ default_host="generativelanguage.googleapis.com",
+ ssl_credentials=None,
+ options=[
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ],
+ )
+
+
+@pytest.mark.parametrize("request_type", [
+ model_service.GetModelRequest,
+ dict,
+])
+def test_get_model(request_type, transport: str = 'grpc'):
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_model),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = model.Model(
+ name='name_value',
+ base_model_id='base_model_id_value',
+ version='version_value',
+ display_name='display_name_value',
+ description='description_value',
+ input_token_limit=1838,
+ output_token_limit=1967,
+ supported_generation_methods=['supported_generation_methods_value'],
+ temperature=0.1198,
+ top_p=0.546,
+ top_k=541,
+ )
+ response = client.get_model(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == model_service.GetModelRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, model.Model)
+ assert response.name == 'name_value'
+ assert response.base_model_id == 'base_model_id_value'
+ assert response.version == 'version_value'
+ assert response.display_name == 'display_name_value'
+ assert response.description == 'description_value'
+ assert response.input_token_limit == 1838
+ assert response.output_token_limit == 1967
+ assert response.supported_generation_methods == ['supported_generation_methods_value']
+ assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6)
+ assert math.isclose(response.top_p, 0.546, rel_tol=1e-6)
+ assert response.top_k == 541
+
+
+def test_get_model_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport='grpc',
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
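+ # Even with no request argument, the client should synthesize a default GetModelRequest.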
+ with mock.patch.object(
+ type(client.transport.get_model),
+ '__call__') as call:
+ client.get_model()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == model_service.GetModelRequest()
+
+@pytest.mark.asyncio
+async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelRequest):
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_model),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model(
+ name='name_value',
+ base_model_id='base_model_id_value',
+ version='version_value',
+ display_name='display_name_value',
+ description='description_value',
+ input_token_limit=1838,
+ output_token_limit=1967,
+ supported_generation_methods=['supported_generation_methods_value'],
+ temperature=0.1198,
+ top_p=0.546,
+ top_k=541,
+ ))
+ response = await client.get_model(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == model_service.GetModelRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, model.Model)
+ assert response.name == 'name_value'
+ assert response.base_model_id == 'base_model_id_value'
+ assert response.version == 'version_value'
+ assert response.display_name == 'display_name_value'
+ assert response.description == 'description_value'
+ assert response.input_token_limit == 1838
+ assert response.output_token_limit == 1967
+ assert response.supported_generation_methods == ['supported_generation_methods_value']
+ assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6)
+ assert math.isclose(response.top_p, 0.546, rel_tol=1e-6)
+ assert response.top_k == 541
+
+
+@pytest.mark.asyncio
+async def test_get_model_async_from_dict():
+ await test_get_model_async(request_type=dict)
+
+
+def test_get_model_field_headers():
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = model_service.GetModelRequest()
+
+ request.name = 'name_value'
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_model),
+ '__call__') as call:
+ call.return_value = model.Model()
+ client.get_model(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert (
+ 'x-goog-request-params',
+ 'name=name_value',
+ ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_get_model_field_headers_async():
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = model_service.GetModelRequest()
+
+ request.name = 'name_value'
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_model),
+ '__call__') as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model())
+ await client.get_model(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert (
+ 'x-goog-request-params',
+ 'name=name_value',
+ ) in kw['metadata']
+
+
+def test_get_model_flattened():
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_model),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = model.Model()
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.get_model(
+ name='name_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].name
+ mock_val = 'name_value'
+ assert arg == mock_val
+
+
+def test_get_model_flattened_error():
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.get_model(
+ model_service.GetModelRequest(),
+ name='name_value',
+ )
+
+@pytest.mark.asyncio
+async def test_get_model_flattened_async():
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_model),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model())
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.get_model(
+ name='name_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].name
+ mock_val = 'name_value'
+ assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_get_model_flattened_error_async():
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.get_model(
+ model_service.GetModelRequest(),
+ name='name_value',
+ )
+
+
+@pytest.mark.parametrize("request_type", [
+ model_service.ListModelsRequest,
+ dict,
+])
+def test_list_models(request_type, transport: str = 'grpc'):
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_models),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = model_service.ListModelsResponse(
+ next_page_token='next_page_token_value',
+ )
+ response = client.list_models(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == model_service.ListModelsRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListModelsPager)
+ assert response.next_page_token == 'next_page_token_value'
+
+
+def test_list_models_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport='grpc',
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_models),
+ '__call__') as call:
+ client.list_models()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == model_service.ListModelsRequest()
+
+@pytest.mark.asyncio
+async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelsRequest):
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_models),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse(
+ next_page_token='next_page_token_value',
+ ))
+ response = await client.list_models(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == model_service.ListModelsRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListModelsAsyncPager)
+ assert response.next_page_token == 'next_page_token_value'
+
+
+@pytest.mark.asyncio
+async def test_list_models_async_from_dict():
+ await test_list_models_async(request_type=dict)
+
+
+def test_list_models_flattened():
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_models),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = model_service.ListModelsResponse()
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.list_models(
+ page_size=951,
+ page_token='page_token_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
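+ # Each flattened keyword argument should have been copied onto the matching request field.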
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].page_size
+ mock_val = 951
+ assert arg == mock_val
+ arg = args[0].page_token
+ mock_val = 'page_token_value'
+ assert arg == mock_val
+
+
+def test_list_models_flattened_error():
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.list_models(
+ model_service.ListModelsRequest(),
+ page_size=951,
+ page_token='page_token_value',
+ )
+
+@pytest.mark.asyncio
+async def test_list_models_flattened_async():
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_models),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse())
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.list_models(
+ page_size=951,
+ page_token='page_token_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].page_size
+ mock_val = 951
+ assert arg == mock_val
+ arg = args[0].page_token
+ mock_val = 'page_token_value'
+ assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_list_models_flattened_error_async():
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.list_models(
+ model_service.ListModelsRequest(),
+ page_size=951,
+ page_token='page_token_value',
+ )
+
+
+def test_list_models_pager(transport_name: str = "grpc"):
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport_name,
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_models),
+ '__call__') as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ model_service.ListModelsResponse(
+ models=[
+ model.Model(),
+ model.Model(),
+ model.Model(),
+ ],
+ next_page_token='abc',
+ ),
+ model_service.ListModelsResponse(
+ models=[],
+ next_page_token='def',
+ ),
+ model_service.ListModelsResponse(
+ models=[
+ model.Model(),
+ ],
+ next_page_token='ghi',
+ ),
+ model_service.ListModelsResponse(
+ models=[
+ model.Model(),
+ model.Model(),
+ ],
+ ),
+ RuntimeError,
+ )
+
+ metadata = ()
+ pager = client.list_models(request={})
+
+ assert pager._metadata == metadata
+
+ results = list(pager)
+ assert len(results) == 6
+ assert all(isinstance(i, model.Model)
+ for i in results)
+
+
+def test_list_models_pages(transport_name: str = "grpc"):
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport_name,
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_models),
+ '__call__') as call:
+ # Set the response to a series of pages.
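+ # Four fake pages are queued; the empty next_page_token on the last page ends iteration before the RuntimeError sentinel is hit.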
+ call.side_effect = (
+ model_service.ListModelsResponse(
+ models=[
+ model.Model(),
+ model.Model(),
+ model.Model(),
+ ],
+ next_page_token='abc',
+ ),
+ model_service.ListModelsResponse(
+ models=[],
+ next_page_token='def',
+ ),
+ model_service.ListModelsResponse(
+ models=[
+ model.Model(),
+ ],
+ next_page_token='ghi',
+ ),
+ model_service.ListModelsResponse(
+ models=[
+ model.Model(),
+ model.Model(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = list(client.list_models(request={}).pages)
+ for page_, token in zip(pages, ['abc','def','ghi', '']):
+ assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_models_async_pager():
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_models),
+ '__call__', new_callable=mock.AsyncMock) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ model_service.ListModelsResponse(
+ models=[
+ model.Model(),
+ model.Model(),
+ model.Model(),
+ ],
+ next_page_token='abc',
+ ),
+ model_service.ListModelsResponse(
+ models=[],
+ next_page_token='def',
+ ),
+ model_service.ListModelsResponse(
+ models=[
+ model.Model(),
+ ],
+ next_page_token='ghi',
+ ),
+ model_service.ListModelsResponse(
+ models=[
+ model.Model(),
+ model.Model(),
+ ],
+ ),
+ RuntimeError,
+ )
+ async_pager = await client.list_models(request={},)
+ assert async_pager.next_page_token == 'abc'
+ responses = []
+ async for response in async_pager: # pragma: no branch
+ responses.append(response)
+
+ assert len(responses) == 6
+ assert all(isinstance(i, model.Model)
+ for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_models_async_pages():
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_models),
+ '__call__', new_callable=mock.AsyncMock) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ model_service.ListModelsResponse(
+ models=[
+ model.Model(),
+ model.Model(),
+ model.Model(),
+ ],
+ next_page_token='abc',
+ ),
+ model_service.ListModelsResponse(
+ models=[],
+ next_page_token='def',
+ ),
+ model_service.ListModelsResponse(
+ models=[
+ model.Model(),
+ ],
+ next_page_token='ghi',
+ ),
+ model_service.ListModelsResponse(
+ models=[
+ model.Model(),
+ model.Model(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = []
+ # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch`
+ # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372
+ async for page_ in ( # pragma: no branch
+ await client.list_models(request={})
+ ).pages:
+ pages.append(page_)
+ for page_, token in zip(pages, ['abc','def','ghi', '']):
+ assert page_.raw_page.next_page_token == token
+
+@pytest.mark.parametrize("request_type", [
+ model_service.GetTunedModelRequest,
+ dict,
+])
+def test_get_tuned_model(request_type, transport: str = 'grpc'):
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_tuned_model),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = tuned_model.TunedModel(
+ name='name_value',
+ display_name='display_name_value',
+ description='description_value',
+ temperature=0.1198,
+ top_p=0.546,
+ top_k=541,
+ state=tuned_model.TunedModel.State.CREATING,
+ base_model='base_model_value',
+ )
+ response = client.get_tuned_model(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == model_service.GetTunedModelRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, tuned_model.TunedModel)
+ assert response.name == 'name_value'
+ assert response.display_name == 'display_name_value'
+ assert response.description == 'description_value'
+ assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6)
+ assert math.isclose(response.top_p, 0.546, rel_tol=1e-6)
+ assert response.top_k == 541
+ assert response.state == tuned_model.TunedModel.State.CREATING
+
+
+def test_get_tuned_model_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport='grpc',
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_tuned_model),
+ '__call__') as call:
+ client.get_tuned_model()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == model_service.GetTunedModelRequest()
+
+@pytest.mark.asyncio
+async def test_get_tuned_model_async(transport: str = 'grpc_asyncio', request_type=model_service.GetTunedModelRequest):
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_tuned_model),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(tuned_model.TunedModel(
+ name='name_value',
+ display_name='display_name_value',
+ description='description_value',
+ temperature=0.1198,
+ top_p=0.546,
+ top_k=541,
+ state=tuned_model.TunedModel.State.CREATING,
+ ))
+ response = await client.get_tuned_model(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == model_service.GetTunedModelRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, tuned_model.TunedModel) + assert response.name == 'name_value' + assert response.display_name == 'display_name_value' + assert response.description == 'description_value' + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + assert response.state == tuned_model.TunedModel.State.CREATING + + +@pytest.mark.asyncio +async def test_get_tuned_model_async_from_dict(): + await test_get_tuned_model_async(request_type=dict) + + +def test_get_tuned_model_field_headers(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetTunedModelRequest() + + request.name = 'name_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tuned_model), + '__call__') as call: + call.return_value = tuned_model.TunedModel() + client.get_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name_value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_get_tuned_model_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetTunedModelRequest() + + request.name = 'name_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tuned_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(tuned_model.TunedModel()) + await client.get_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name_value', + ) in kw['metadata'] + + +def test_get_tuned_model_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_tuned_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = tuned_model.TunedModel() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_tuned_model( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = 'name_value' + assert arg == mock_val + + +def test_get_tuned_model_flattened_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
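+ # The mixed call should fail fast with ValueError before any RPC is attempted.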
+ with pytest.raises(ValueError):
+ client.get_tuned_model(
+ model_service.GetTunedModelRequest(),
+ name='name_value',
+ )
+
+@pytest.mark.asyncio
+async def test_get_tuned_model_flattened_async():
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_tuned_model),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(tuned_model.TunedModel())
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.get_tuned_model(
+ name='name_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].name
+ mock_val = 'name_value'
+ assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_get_tuned_model_flattened_error_async():
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.get_tuned_model(
+ model_service.GetTunedModelRequest(),
+ name='name_value',
+ )
+
+
+@pytest.mark.parametrize("request_type", [
+ model_service.ListTunedModelsRequest,
+ dict,
+])
+def test_list_tuned_models(request_type, transport: str = 'grpc'):
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_tuned_models),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = model_service.ListTunedModelsResponse(
+ next_page_token='next_page_token_value',
+ )
+ response = client.list_tuned_models(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == model_service.ListTunedModelsRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListTunedModelsPager)
+ assert response.next_page_token == 'next_page_token_value'
+
+
+def test_list_tuned_models_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport='grpc',
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_tuned_models),
+ '__call__') as call:
+ client.list_tuned_models()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == model_service.ListTunedModelsRequest()
+
+@pytest.mark.asyncio
+async def test_list_tuned_models_async(transport: str = 'grpc_asyncio', request_type=model_service.ListTunedModelsRequest):
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_tuned_models),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListTunedModelsResponse(
+ next_page_token='next_page_token_value',
+ ))
+ response = await client.list_tuned_models(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == model_service.ListTunedModelsRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, pagers.ListTunedModelsAsyncPager)
+ assert response.next_page_token == 'next_page_token_value'
+
+
+@pytest.mark.asyncio
+async def test_list_tuned_models_async_from_dict():
+ await test_list_tuned_models_async(request_type=dict)
+
+
+def test_list_tuned_models_flattened():
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_tuned_models),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = model_service.ListTunedModelsResponse()
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.list_tuned_models(
+ page_size=951,
+ page_token='page_token_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].page_size
+ mock_val = 951
+ assert arg == mock_val
+ arg = args[0].page_token
+ mock_val = 'page_token_value'
+ assert arg == mock_val
+
+
+def test_list_tuned_models_flattened_error():
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.list_tuned_models(
+ model_service.ListTunedModelsRequest(),
+ page_size=951,
+ page_token='page_token_value',
+ )
+
+@pytest.mark.asyncio
+async def test_list_tuned_models_flattened_async():
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_tuned_models),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListTunedModelsResponse())
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.list_tuned_models(
+ page_size=951,
+ page_token='page_token_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].page_size
+ mock_val = 951
+ assert arg == mock_val
+ arg = args[0].page_token
+ mock_val = 'page_token_value'
+ assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_list_tuned_models_flattened_error_async():
+ client = ModelServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.list_tuned_models(
+ model_service.ListTunedModelsRequest(),
+ page_size=951,
+ page_token='page_token_value',
+ )
+
+
+def test_list_tuned_models_pager(transport_name: str = "grpc"):
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport_name,
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_tuned_models),
+ '__call__') as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ model_service.ListTunedModelsResponse(
+ tuned_models=[
+ tuned_model.TunedModel(),
+ tuned_model.TunedModel(),
+ tuned_model.TunedModel(),
+ ],
+ next_page_token='abc',
+ ),
+ model_service.ListTunedModelsResponse(
+ tuned_models=[],
+ next_page_token='def',
+ ),
+ model_service.ListTunedModelsResponse(
+ tuned_models=[
+ tuned_model.TunedModel(),
+ ],
+ next_page_token='ghi',
+ ),
+ model_service.ListTunedModelsResponse(
+ tuned_models=[
+ tuned_model.TunedModel(),
+ tuned_model.TunedModel(),
+ ],
+ ),
+ RuntimeError,
+ )
+
+ metadata = ()
+ pager = client.list_tuned_models(request={})
+
+ assert pager._metadata == metadata
+
+ results = list(pager)
+ assert len(results) == 6
+ assert all(isinstance(i, tuned_model.TunedModel)
+ for i in results)
+
+
+def test_list_tuned_models_pages(transport_name: str = "grpc"):
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport_name,
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_tuned_models),
+ '__call__') as call:
+ # Set the response to a series of pages.
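+ # As in the pager test above, iteration should stop at the page with an empty next_page_token.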
+        call.side_effect = (
+            model_service.ListTunedModelsResponse(
+                tuned_models=[
+                    tuned_model.TunedModel(),
+                    tuned_model.TunedModel(),
+                    tuned_model.TunedModel(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListTunedModelsResponse(
+                tuned_models=[],
+                next_page_token='def',
+            ),
+            model_service.ListTunedModelsResponse(
+                tuned_models=[
+                    tuned_model.TunedModel(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListTunedModelsResponse(
+                tuned_models=[
+                    tuned_model.TunedModel(),
+                    tuned_model.TunedModel(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_tuned_models(request={}).pages)
+        for page_, token in zip(pages, ['abc','def','ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_tuned_models_async_pager():
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_tuned_models),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListTunedModelsResponse(
+                tuned_models=[
+                    tuned_model.TunedModel(),
+                    tuned_model.TunedModel(),
+                    tuned_model.TunedModel(),
+                ],
+                next_page_token='abc',
+            ),
+            model_service.ListTunedModelsResponse(
+                tuned_models=[],
+                next_page_token='def',
+            ),
+            model_service.ListTunedModelsResponse(
+                tuned_models=[
+                    tuned_model.TunedModel(),
+                ],
+                next_page_token='ghi',
+            ),
+            model_service.ListTunedModelsResponse(
+                tuned_models=[
+                    tuned_model.TunedModel(),
+                    tuned_model.TunedModel(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_tuned_models(request={},)
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager: # pragma: no branch
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, tuned_model.TunedModel)
+                   for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_tuned_models_async_pages():
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_tuned_models),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+ call.side_effect = ( + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + next_page_token='abc', + ), + model_service.ListTunedModelsResponse( + tuned_models=[], + next_page_token='def', + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + ], + next_page_token='ghi', + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_tuned_models(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.parametrize("request_type", [ + model_service.CreateTunedModelRequest, + dict, +]) +def test_create_tuned_model(request_type, transport: str = 'grpc'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tuned_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name='operations/spam') + response = client.create_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.CreateTunedModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_tuned_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tuned_model), + '__call__') as call: + client.create_tuned_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.CreateTunedModelRequest() + +@pytest.mark.asyncio +async def test_create_tuned_model_async(transport: str = 'grpc_asyncio', request_type=model_service.CreateTunedModelRequest): + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tuned_model), + '__call__') as call: + # Designate an appropriate return value for the call. 
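+        # create_tuned_model is a long-running operation: the stub returns an
+        # operations_pb2.Operation, which the client wraps in the future.Future
+        # asserted below rather than a plain response message.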
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        response = await client.create_tuned_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == model_service.CreateTunedModelRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+@pytest.mark.asyncio
+async def test_create_tuned_model_async_from_dict():
+    await test_create_tuned_model_async(request_type=dict)
+
+
+def test_create_tuned_model_flattened():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_tuned_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name='operations/op')
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.create_tuned_model(
+            tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')),
+            tuned_model_id='tuned_model_id_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].tuned_model
+        mock_val = gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value'))
+        assert arg == mock_val
+        arg = args[0].tuned_model_id
+        mock_val = 'tuned_model_id_value'
+        assert arg == mock_val
+
+
+def test_create_tuned_model_flattened_error():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.create_tuned_model(
+            model_service.CreateTunedModelRequest(),
+            tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')),
+            tuned_model_id='tuned_model_id_value',
+        )
+
+@pytest.mark.asyncio
+async def test_create_tuned_model_flattened_async():
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_tuned_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name='operations/spam')
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.create_tuned_model(
+            tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')),
+            tuned_model_id='tuned_model_id_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].tuned_model + mock_val = gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')) + assert arg == mock_val + arg = args[0].tuned_model_id + mock_val = 'tuned_model_id_value' + assert arg == mock_val + +@pytest.mark.asyncio +async def test_create_tuned_model_flattened_error_async(): + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_tuned_model( + model_service.CreateTunedModelRequest(), + tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), + tuned_model_id='tuned_model_id_value', + ) + + +@pytest.mark.parametrize("request_type", [ + model_service.UpdateTunedModelRequest, + dict, +]) +def test_update_tuned_model(request_type, transport: str = 'grpc'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tuned_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gag_tuned_model.TunedModel( + name='name_value', + display_name='display_name_value', + description='description_value', + temperature=0.1198, + top_p=0.546, + top_k=541, + state=gag_tuned_model.TunedModel.State.CREATING, + base_model='base_model_value', + ) + response = client.update_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.UpdateTunedModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_tuned_model.TunedModel) + assert response.name == 'name_value' + assert response.display_name == 'display_name_value' + assert response.description == 'description_value' + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + assert response.state == gag_tuned_model.TunedModel.State.CREATING + + +def test_update_tuned_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+    with mock.patch.object(
+            type(client.transport.update_tuned_model),
+            '__call__') as call:
+        client.update_tuned_model()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == model_service.UpdateTunedModelRequest()
+
+@pytest.mark.asyncio
+async def test_update_tuned_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UpdateTunedModelRequest):
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_tuned_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_tuned_model.TunedModel(
+            name='name_value',
+            display_name='display_name_value',
+            description='description_value',
+            temperature=0.1198,
+            top_p=0.546,
+            top_k=541,
+            state=gag_tuned_model.TunedModel.State.CREATING,
+        ))
+        response = await client.update_tuned_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == model_service.UpdateTunedModelRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, gag_tuned_model.TunedModel)
+    assert response.name == 'name_value'
+    assert response.display_name == 'display_name_value'
+    assert response.description == 'description_value'
+    assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6)
+    assert math.isclose(response.top_p, 0.546, rel_tol=1e-6)
+    assert response.top_k == 541
+    assert response.state == gag_tuned_model.TunedModel.State.CREATING
+
+
+@pytest.mark.asyncio
+async def test_update_tuned_model_async_from_dict():
+    await test_update_tuned_model_async(request_type=dict)
+
+
+def test_update_tuned_model_field_headers():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.UpdateTunedModelRequest()
+
+    request.tuned_model.name = 'name_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_tuned_model),
+            '__call__') as call:
+        call.return_value = gag_tuned_model.TunedModel()
+        client.update_tuned_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'tuned_model.name=name_value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_update_tuned_model_field_headers_async():
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.UpdateTunedModelRequest()
+
+    request.tuned_model.name = 'name_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_tuned_model),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_tuned_model.TunedModel())
+        await client.update_tuned_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'tuned_model.name=name_value',
+    ) in kw['metadata']
+
+
+def test_update_tuned_model_flattened():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_tuned_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = gag_tuned_model.TunedModel()
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.update_tuned_model(
+            tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')),
+            update_mask=field_mask_pb2.FieldMask(paths=['paths_value']),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].tuned_model
+        mock_val = gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value'))
+        assert arg == mock_val
+        arg = args[0].update_mask
+        mock_val = field_mask_pb2.FieldMask(paths=['paths_value'])
+        assert arg == mock_val
+
+
+def test_update_tuned_model_flattened_error():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.update_tuned_model(
+            model_service.UpdateTunedModelRequest(),
+            tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')),
+            update_mask=field_mask_pb2.FieldMask(paths=['paths_value']),
+        )
+
+@pytest.mark.asyncio
+async def test_update_tuned_model_flattened_async():
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_tuned_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_tuned_model.TunedModel())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.update_tuned_model(
+            tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')),
+            update_mask=field_mask_pb2.FieldMask(paths=['paths_value']),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].tuned_model + mock_val = gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')) + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=['paths_value']) + assert arg == mock_val + +@pytest.mark.asyncio +async def test_update_tuned_model_flattened_error_async(): + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_tuned_model( + model_service.UpdateTunedModelRequest(), + tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), + update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + ) + + +@pytest.mark.parametrize("request_type", [ + model_service.DeleteTunedModelRequest, + dict, +]) +def test_delete_tuned_model(request_type, transport: str = 'grpc'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.delete_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.DeleteTunedModelRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_tuned_model_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), + '__call__') as call: + client.delete_tuned_model() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.DeleteTunedModelRequest() + +@pytest.mark.asyncio +async def test_delete_tuned_model_async(transport: str = 'grpc_asyncio', request_type=model_service.DeleteTunedModelRequest): + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.DeleteTunedModelRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_tuned_model_async_from_dict(): + await test_delete_tuned_model_async(request_type=dict) + + +def test_delete_tuned_model_field_headers(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.DeleteTunedModelRequest() + + request.name = 'name_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), + '__call__') as call: + call.return_value = None + client.delete_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name_value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_delete_tuned_model_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.DeleteTunedModelRequest() + + request.name = 'name_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name_value', + ) in kw['metadata'] + + +def test_delete_tuned_model_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = None + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_tuned_model( + name='name_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = 'name_value' + assert arg == mock_val + + +def test_delete_tuned_model_flattened_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        client.delete_tuned_model(
+            model_service.DeleteTunedModelRequest(),
+            name='name_value',
+        )
+
+@pytest.mark.asyncio
+async def test_delete_tuned_model_flattened_async():
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.delete_tuned_model),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_tuned_model(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].name
+        mock_val = 'name_value'
+        assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_delete_tuned_model_flattened_error_async():
+    client = ModelServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.delete_tuned_model(
+            model_service.DeleteTunedModelRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.parametrize("request_type", [
+    model_service.GetModelRequest,
+    dict,
+])
+def test_get_model_rest(request_type):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'name': 'models/sample1'}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a response.
+    with mock.patch.object(type(client.transport._session), 'request') as req:
+        # Designate an appropriate value for the returned response.
+        return_value = model.Model(
+            name='name_value',
+            base_model_id='base_model_id_value',
+            version='version_value',
+            display_name='display_name_value',
+            description='description_value',
+            input_token_limit=1838,
+            output_token_limit=1967,
+            supported_generation_methods=['supported_generation_methods_value'],
+            temperature=0.1198,
+            top_p=0.546,
+            top_k=541,
+        )
+
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 200
+        pb_return_value = model.Model.pb(return_value)
+        json_return_value = json_format.MessageToJson(pb_return_value)
+
+        response_value._content = json_return_value.encode('UTF-8')
+        req.return_value = response_value
+        response = client.get_model(request)
+
+    # Establish that the response is the type that we expect.
+ assert isinstance(response, model.Model) + assert response.name == 'name_value' + assert response.base_model_id == 'base_model_id_value' + assert response.version == 'version_value' + assert response.display_name == 'display_name_value' + assert response.description == 'description_value' + assert response.input_token_limit == 1838 + assert response.output_token_limit == 1967 + assert response.supported_generation_methods == ['supported_generation_methods_value'] + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + + +def test_get_model_rest_required_fields(request_type=model_service.GetModelRequest): + transport_class = transports.ModelServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = 'name_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == 'name_value' + + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = model.Model() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
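+            # transcode() normally maps the proto request onto an HTTP verb,
+            # URI, and query string using the method's http annotations;
+            # mocking it lets the test pin those pieces to fixed values.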
+            pb_request = request_type.pb(request)
+            transcode_result = {
+                'uri': 'v1/sample_method',
+                'method': "get",
+                'query_params': pb_request,
+            }
+            transcode.return_value = transcode_result
+
+            response_value = Response()
+            response_value.status_code = 200
+
+            pb_return_value = model.Model.pb(return_value)
+            json_return_value = json_format.MessageToJson(pb_return_value)
+
+            response_value._content = json_return_value.encode('UTF-8')
+            req.return_value = response_value
+
+            response = client.get_model(request)
+
+            expected_params = [
+                ('$alt', 'json;enum-encoding=int')
+            ]
+            actual_params = req.call_args.kwargs['params']
+            assert expected_params == actual_params
+
+
+def test_get_model_rest_unset_required_fields():
+    transport = transports.ModelServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+    unset_fields = transport.get_model._get_unset_required_fields({})
+    assert set(unset_fields) == (set(()) & set(("name", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_get_model_rest_interceptors(null_interceptor):
+    transport = transports.ModelServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(),
+    )
+    client = ModelServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+         mock.patch.object(path_template, "transcode") as transcode, \
+         mock.patch.object(transports.ModelServiceRestInterceptor, "post_get_model") as post, \
+         mock.patch.object(transports.ModelServiceRestInterceptor, "pre_get_model") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = model_service.GetModelRequest.pb(model_service.GetModelRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = model.Model.to_json(model.Model())
+
+        request = model_service.GetModelRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = model.Model()
+
+        client.get_model(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_get_model_rest_bad_request(transport: str = 'rest', request_type=model_service.GetModelRequest):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'name': 'models/sample1'}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
+    with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 400
+        response_value.request = Request()
+        req.return_value = response_value
+        client.get_model(request)
+
+
+def test_get_model_rest_flattened():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # Mock the http request call within the method and fake a response.
+    with mock.patch.object(type(client.transport._session), 'request') as req:
+        # Designate an appropriate value for the returned response.
+ return_value = model.Model() + + # get arguments that satisfy an http rule for this method + sample_request = {'name': 'models/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + name='name_value', + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = model.Model.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.get_model(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{name=models/*}" % client.transport._host, args[1]) + + +def test_get_model_rest_flattened_error(transport: str = 'rest'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_model( + model_service.GetModelRequest(), + name='name_value', + ) + + +def test_get_model_rest_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + model_service.ListModelsRequest, + dict, +]) +def test_list_models_rest(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = model_service.ListModelsResponse( + next_page_token='next_page_token_value', + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = model_service.ListModelsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.list_models(request) + + # Establish that the response is the type that we expect. 
+    assert isinstance(response, pagers.ListModelsPager)
+    assert response.next_page_token == 'next_page_token_value'
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_list_models_rest_interceptors(null_interceptor):
+    transport = transports.ModelServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(),
+    )
+    client = ModelServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+         mock.patch.object(path_template, "transcode") as transcode, \
+         mock.patch.object(transports.ModelServiceRestInterceptor, "post_list_models") as post, \
+         mock.patch.object(transports.ModelServiceRestInterceptor, "pre_list_models") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = model_service.ListModelsRequest.pb(model_service.ListModelsRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = model_service.ListModelsResponse.to_json(model_service.ListModelsResponse())
+
+        request = model_service.ListModelsRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = model_service.ListModelsResponse()
+
+        client.list_models(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_list_models_rest_bad_request(transport: str = 'rest', request_type=model_service.ListModelsRequest):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
+    with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 400
+        response_value.request = Request()
+        req.return_value = response_value
+        client.list_models(request)
+
+
+def test_list_models_rest_flattened():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # Mock the http request call within the method and fake a response.
+    with mock.patch.object(type(client.transport._session), 'request') as req:
+        # Designate an appropriate value for the returned response.
+        return_value = model_service.ListModelsResponse()
+
+        # get arguments that satisfy an http rule for this method
+        sample_request = {}
+
+        # get truthy value for each flattened field
+        mock_args = dict(
+            page_size=951,
+            page_token='page_token_value',
+        )
+        mock_args.update(sample_request)
+
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 200
+        pb_return_value = model_service.ListModelsResponse.pb(return_value)
+        json_return_value = json_format.MessageToJson(pb_return_value)
+        response_value._content = json_return_value.encode('UTF-8')
+        req.return_value = response_value
+
+        client.list_models(**mock_args)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
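+        # path_template.validate() below checks that the URI actually hit
+        # matches this method's v1beta3 http rule template.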
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/models" % client.transport._host, args[1]) + + +def test_list_models_rest_flattened_error(transport: str = 'rest'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_models( + model_service.ListModelsRequest(), + page_size=951, + page_token='page_token_value', + ) + + +def test_list_models_rest_pager(transport: str = 'rest'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + #with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token='abc', + ), + model_service.ListModelsResponse( + models=[], + next_page_token='def', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + ], + next_page_token='ghi', + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(model_service.ListModelsResponse.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode('UTF-8') + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {} + + pager = client.list_models(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, model.Model) + for i in results) + + pages = list(client.list_models(request=sample_request).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize("request_type", [ + model_service.GetTunedModelRequest, + dict, +]) +def test_get_tuned_model_rest(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'name': 'tunedModels/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. 
+ return_value = tuned_model.TunedModel( + name='name_value', + display_name='display_name_value', + description='description_value', + temperature=0.1198, + top_p=0.546, + top_k=541, + state=tuned_model.TunedModel.State.CREATING, + base_model='base_model_value', + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = tuned_model.TunedModel.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.get_tuned_model(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, tuned_model.TunedModel) + assert response.name == 'name_value' + assert response.display_name == 'display_name_value' + assert response.description == 'description_value' + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + assert response.state == tuned_model.TunedModel.State.CREATING + + +def test_get_tuned_model_rest_required_fields(request_type=model_service.GetTunedModelRequest): + transport_class = transports.ModelServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_tuned_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = 'name_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_tuned_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == 'name_value' + + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = tuned_model.TunedModel() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+            pb_request = request_type.pb(request)
+            transcode_result = {
+                'uri': 'v1/sample_method',
+                'method': "get",
+                'query_params': pb_request,
+            }
+            transcode.return_value = transcode_result
+
+            response_value = Response()
+            response_value.status_code = 200
+
+            pb_return_value = tuned_model.TunedModel.pb(return_value)
+            json_return_value = json_format.MessageToJson(pb_return_value)
+
+            response_value._content = json_return_value.encode('UTF-8')
+            req.return_value = response_value
+
+            response = client.get_tuned_model(request)
+
+            expected_params = [
+                ('$alt', 'json;enum-encoding=int')
+            ]
+            actual_params = req.call_args.kwargs['params']
+            assert expected_params == actual_params
+
+
+def test_get_tuned_model_rest_unset_required_fields():
+    transport = transports.ModelServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+    unset_fields = transport.get_tuned_model._get_unset_required_fields({})
+    assert set(unset_fields) == (set(()) & set(("name", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_get_tuned_model_rest_interceptors(null_interceptor):
+    transport = transports.ModelServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(),
+    )
+    client = ModelServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+         mock.patch.object(path_template, "transcode") as transcode, \
+         mock.patch.object(transports.ModelServiceRestInterceptor, "post_get_tuned_model") as post, \
+         mock.patch.object(transports.ModelServiceRestInterceptor, "pre_get_tuned_model") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = model_service.GetTunedModelRequest.pb(model_service.GetTunedModelRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = tuned_model.TunedModel.to_json(tuned_model.TunedModel())
+
+        request = model_service.GetTunedModelRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = tuned_model.TunedModel()
+
+        client.get_tuned_model(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_get_tuned_model_rest_bad_request(transport: str = 'rest', request_type=model_service.GetTunedModelRequest):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'name': 'tunedModels/sample1'}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
+    with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 400
+        response_value.request = Request()
+        req.return_value = response_value
+        client.get_tuned_model(request)
+
+
+def test_get_tuned_model_rest_flattened():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # Mock the http request call within the method and fake a response.
+ with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = tuned_model.TunedModel() + + # get arguments that satisfy an http rule for this method + sample_request = {'name': 'tunedModels/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + name='name_value', + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = tuned_model.TunedModel.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.get_tuned_model(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{name=tunedModels/*}" % client.transport._host, args[1]) + + +def test_get_tuned_model_rest_flattened_error(transport: str = 'rest'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_tuned_model( + model_service.GetTunedModelRequest(), + name='name_value', + ) + + +def test_get_tuned_model_rest_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + model_service.ListTunedModelsRequest, + dict, +]) +def test_list_tuned_models_rest(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = model_service.ListTunedModelsResponse( + next_page_token='next_page_token_value', + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = model_service.ListTunedModelsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.list_tuned_models(request) + + # Establish that the response is the type that we expect. 
+    assert isinstance(response, pagers.ListTunedModelsPager)
+    assert response.next_page_token == 'next_page_token_value'
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_list_tuned_models_rest_interceptors(null_interceptor):
+    transport = transports.ModelServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(),
+    )
+    client = ModelServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+         mock.patch.object(path_template, "transcode") as transcode, \
+         mock.patch.object(transports.ModelServiceRestInterceptor, "post_list_tuned_models") as post, \
+         mock.patch.object(transports.ModelServiceRestInterceptor, "pre_list_tuned_models") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = model_service.ListTunedModelsRequest.pb(model_service.ListTunedModelsRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = model_service.ListTunedModelsResponse.to_json(model_service.ListTunedModelsResponse())
+
+        request = model_service.ListTunedModelsRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = model_service.ListTunedModelsResponse()
+
+        client.list_tuned_models(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_list_tuned_models_rest_bad_request(transport: str = 'rest', request_type=model_service.ListTunedModelsRequest):
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
+    with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 400
+        response_value.request = Request()
+        req.return_value = response_value
+        client.list_tuned_models(request)
+
+
+def test_list_tuned_models_rest_flattened():
+    client = ModelServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # Mock the http request call within the method and fake a response.
+    with mock.patch.object(type(client.transport._session), 'request') as req:
+        # Designate an appropriate value for the returned response.
+ return_value = model_service.ListTunedModelsResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {} + + # get truthy value for each flattened field + mock_args = dict( + page_size=951, + page_token='page_token_value', + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = model_service.ListTunedModelsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.list_tuned_models(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/tunedModels" % client.transport._host, args[1]) + + +def test_list_tuned_models_rest_flattened_error(transport: str = 'rest'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_tuned_models( + model_service.ListTunedModelsRequest(), + page_size=951, + page_token='page_token_value', + ) + + +def test_list_tuned_models_rest_pager(transport: str = 'rest'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + #with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + next_page_token='abc', + ), + model_service.ListTunedModelsResponse( + tuned_models=[], + next_page_token='def', + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + ], + next_page_token='ghi', + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(model_service.ListTunedModelsResponse.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode('UTF-8') + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {} + + pager = client.list_tuned_models(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, tuned_model.TunedModel) + for i in results) + + pages = list(client.list_tuned_models(request=sample_request).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize("request_type", [ + model_service.CreateTunedModelRequest, + dict, +]) +def test_create_tuned_model_rest(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + 
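+    # The request body below enumerates every TunedModel field, including the
+    # nested tuning_task message, so the JSON transcoding path is exercised
+    # end to end.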
request_init = {} + request_init["tuned_model"] = {'tuned_model_source': {'tuned_model': 'tuned_model_value', 'base_model': 'base_model_value'}, 'base_model': 'base_model_value', 'name': 'name_value', 'display_name': 'display_name_value', 'description': 'description_value', 'temperature': 0.1198, 'top_p': 0.546, 'top_k': 541, 'state': 1, 'create_time': {'seconds': 751, 'nanos': 543}, 'update_time': {}, 'tuning_task': {'start_time': {}, 'complete_time': {}, 'snapshots': [{'step': 444, 'epoch': 527, 'mean_loss': 0.961, 'compute_time': {}}], 'training_data': {'examples': {'examples': [{'text_input': 'text_input_value', 'output': 'output_value'}]}}, 'hyperparameters': {'epoch_count': 1175, 'batch_size': 1052, 'learning_rate': 0.1371}}} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name='operations/spam') + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.create_tuned_model(request) + + # Establish that the response is the type that we expect. + assert response.operation.name == "operations/spam" + + +def test_create_tuned_model_rest_required_fields(request_type=model_service.CreateTunedModelRequest): + transport_class = transports.ModelServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).create_tuned_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).create_tuned_model._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("tuned_model_id", )) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name='operations/spam') + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
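+ # transcode() normally maps the request onto the method's http rule
+ # (URI template, verb, body); stubbing it with a canned result lets the
+ # test concentrate on the query-string encoding asserted further down.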
+ pb_request = request_type.pb(request)
+ transcode_result = {
+ 'uri': 'v1/sample_method',
+ 'method': "post",
+ 'query_params': pb_request,
+ }
+ transcode_result['body'] = pb_request
+ transcode.return_value = transcode_result
+
+ response_value = Response()
+ response_value.status_code = 200
+ json_return_value = json_format.MessageToJson(return_value)
+
+ response_value._content = json_return_value.encode('UTF-8')
+ req.return_value = response_value
+
+ response = client.create_tuned_model(request)
+
+ expected_params = [
+ ('$alt', 'json;enum-encoding=int')
+ ]
+ actual_params = req.call_args.kwargs['params']
+ assert expected_params == actual_params
+
+
+def test_create_tuned_model_rest_unset_required_fields():
+ transport = transports.ModelServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+ unset_fields = transport.create_tuned_model._get_unset_required_fields({})
+ assert set(unset_fields) == (set(("tunedModelId", )) & set(("tunedModel", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_create_tuned_model_rest_interceptors(null_interceptor):
+ transport = transports.ModelServiceRestTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(),
+ )
+ client = ModelServiceClient(transport=transport)
+ with mock.patch.object(type(client.transport._session), "request") as req, \
+ mock.patch.object(path_template, "transcode") as transcode, \
+ mock.patch.object(operation.Operation, "_set_result_from_operation"), \
+ mock.patch.object(transports.ModelServiceRestInterceptor, "post_create_tuned_model") as post, \
+ mock.patch.object(transports.ModelServiceRestInterceptor, "pre_create_tuned_model") as pre:
+ pre.assert_not_called()
+ post.assert_not_called()
+ pb_message = model_service.CreateTunedModelRequest.pb(model_service.CreateTunedModelRequest())
+ transcode.return_value = {
+ "method": "post",
+ "uri": "my_uri",
+ "body": pb_message,
+ "query_params": pb_message,
+ }
+
+ req.return_value = Response()
+ req.return_value.status_code = 200
+ req.return_value.request = PreparedRequest()
+ req.return_value._content = json_format.MessageToJson(operations_pb2.Operation())
+
+ request = model_service.CreateTunedModelRequest()
+ metadata = [
+ ("key", "val"),
+ ("cephalopod", "squid"),
+ ]
+ pre.return_value = request, metadata
+ post.return_value = operations_pb2.Operation()
+
+ client.create_tuned_model(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+ pre.assert_called_once()
+ post.assert_called_once()
+
+
+def test_create_tuned_model_rest_bad_request(transport: str = 'rest', request_type=model_service.CreateTunedModelRequest):
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # send a request that will satisfy transcoding
+ request_init = {}
+ request_init["tuned_model"] = {'tuned_model_source': {'tuned_model': 'tuned_model_value', 'base_model': 'base_model_value'}, 'base_model': 'base_model_value', 'name': 'name_value', 'display_name': 'display_name_value', 'description': 'description_value', 'temperature': 0.1198, 'top_p': 0.546, 'top_k': 541, 'state': 1, 'create_time': {'seconds': 751, 'nanos': 543}, 'update_time': {}, 'tuning_task': {'start_time': {}, 'complete_time': {}, 'snapshots': [{'step': 444, 'epoch': 527, 'mean_loss': 0.961, 'compute_time': {}}], 'training_data': {'examples': {'examples': [{'text_input': 'text_input_value', 'output': 'output_value'}]}},
'hyperparameters': {'epoch_count': 1175, 'batch_size': 1052, 'learning_rate': 0.1371}}} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.create_tuned_model(request) + + +def test_create_tuned_model_rest_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name='operations/spam') + + # get arguments that satisfy an http rule for this method + sample_request = {} + + # get truthy value for each flattened field + mock_args = dict( + tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), + tuned_model_id='tuned_model_id_value', + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.create_tuned_model(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/tunedModels" % client.transport._host, args[1]) + + +def test_create_tuned_model_rest_flattened_error(transport: str = 'rest'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
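+ # The two calling styles are mutually exclusive: callers pass either a
+ # populated request message, e.g.
+ #   client.create_tuned_model(request=model_service.CreateTunedModelRequest())
+ # or individual flattened fields, but never both in the same call.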
+ with pytest.raises(ValueError): + client.create_tuned_model( + model_service.CreateTunedModelRequest(), + tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), + tuned_model_id='tuned_model_id_value', + ) + + +def test_create_tuned_model_rest_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + model_service.UpdateTunedModelRequest, + dict, +]) +def test_update_tuned_model_rest(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'tuned_model': {'name': 'tunedModels/sample1'}} + request_init["tuned_model"] = {'tuned_model_source': {'tuned_model': 'tuned_model_value', 'base_model': 'base_model_value'}, 'base_model': 'base_model_value', 'name': 'tunedModels/sample1', 'display_name': 'display_name_value', 'description': 'description_value', 'temperature': 0.1198, 'top_p': 0.546, 'top_k': 541, 'state': 1, 'create_time': {'seconds': 751, 'nanos': 543}, 'update_time': {}, 'tuning_task': {'start_time': {}, 'complete_time': {}, 'snapshots': [{'step': 444, 'epoch': 527, 'mean_loss': 0.961, 'compute_time': {}}], 'training_data': {'examples': {'examples': [{'text_input': 'text_input_value', 'output': 'output_value'}]}}, 'hyperparameters': {'epoch_count': 1175, 'batch_size': 1052, 'learning_rate': 0.1371}}} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = gag_tuned_model.TunedModel( + name='name_value', + display_name='display_name_value', + description='description_value', + temperature=0.1198, + top_p=0.546, + top_k=541, + state=gag_tuned_model.TunedModel.State.CREATING, + base_model='base_model_value', + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = gag_tuned_model.TunedModel.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.update_tuned_model(request) + + # Establish that the response is the type that we expect. 
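+ # The float fields below are compared with math.isclose because they
+ # round-trip through JSON and may not come back bit-for-bit identical.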
+ assert isinstance(response, gag_tuned_model.TunedModel) + assert response.name == 'name_value' + assert response.display_name == 'display_name_value' + assert response.description == 'description_value' + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + assert response.state == gag_tuned_model.TunedModel.State.CREATING + + +def test_update_tuned_model_rest_required_fields(request_type=model_service.UpdateTunedModelRequest): + transport_class = transports.ModelServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).update_tuned_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).update_tuned_model._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("update_mask", )) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = gag_tuned_model.TunedModel() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request)
+ transcode_result = {
+ 'uri': 'v1/sample_method',
+ 'method': "patch",
+ 'query_params': pb_request,
+ }
+ transcode_result['body'] = pb_request
+ transcode.return_value = transcode_result
+
+ response_value = Response()
+ response_value.status_code = 200
+
+ pb_return_value = gag_tuned_model.TunedModel.pb(return_value)
+ json_return_value = json_format.MessageToJson(pb_return_value)
+
+ response_value._content = json_return_value.encode('UTF-8')
+ req.return_value = response_value
+
+ response = client.update_tuned_model(request)
+
+ expected_params = [
+ ('$alt', 'json;enum-encoding=int')
+ ]
+ actual_params = req.call_args.kwargs['params']
+ assert expected_params == actual_params
+
+
+def test_update_tuned_model_rest_unset_required_fields():
+ transport = transports.ModelServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+ unset_fields = transport.update_tuned_model._get_unset_required_fields({})
+ assert set(unset_fields) == (set(("updateMask", )) & set(("tunedModel", "updateMask", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_update_tuned_model_rest_interceptors(null_interceptor):
+ transport = transports.ModelServiceRestTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(),
+ )
+ client = ModelServiceClient(transport=transport)
+ with mock.patch.object(type(client.transport._session), "request") as req, \
+ mock.patch.object(path_template, "transcode") as transcode, \
+ mock.patch.object(transports.ModelServiceRestInterceptor, "post_update_tuned_model") as post, \
+ mock.patch.object(transports.ModelServiceRestInterceptor, "pre_update_tuned_model") as pre:
+ pre.assert_not_called()
+ post.assert_not_called()
+ pb_message = model_service.UpdateTunedModelRequest.pb(model_service.UpdateTunedModelRequest())
+ transcode.return_value = {
+ "method": "post",
+ "uri": "my_uri",
+ "body": pb_message,
+ "query_params": pb_message,
+ }
+
+ req.return_value = Response()
+ req.return_value.status_code = 200
+ req.return_value.request = PreparedRequest()
+ req.return_value._content = gag_tuned_model.TunedModel.to_json(gag_tuned_model.TunedModel())
+
+ request = model_service.UpdateTunedModelRequest()
+ metadata = [
+ ("key", "val"),
+ ("cephalopod", "squid"),
+ ]
+ pre.return_value = request, metadata
+ post.return_value = gag_tuned_model.TunedModel()
+
+ client.update_tuned_model(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+ pre.assert_called_once()
+ post.assert_called_once()
+
+
+def test_update_tuned_model_rest_bad_request(transport: str = 'rest', request_type=model_service.UpdateTunedModelRequest):
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # send a request that will satisfy transcoding
+ request_init = {'tuned_model': {'name': 'tunedModels/sample1'}}
+ request_init["tuned_model"] = {'tuned_model_source': {'tuned_model': 'tuned_model_value', 'base_model': 'base_model_value'}, 'base_model': 'base_model_value', 'name': 'tunedModels/sample1', 'display_name': 'display_name_value', 'description': 'description_value', 'temperature': 0.1198, 'top_p': 0.546, 'top_k': 541, 'state': 1, 'create_time': {'seconds': 751, 'nanos': 543}, 'update_time': {}, 'tuning_task': {'start_time': {}, 'complete_time': {}, 'snapshots': [{'step': 444, 'epoch': 527, 'mean_loss': 0.961, 'compute_time': {}}], 'training_data': {'examples':
{'examples': [{'text_input': 'text_input_value', 'output': 'output_value'}]}}, 'hyperparameters': {'epoch_count': 1175, 'batch_size': 1052, 'learning_rate': 0.1371}}} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.update_tuned_model(request) + + +def test_update_tuned_model_rest_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = gag_tuned_model.TunedModel() + + # get arguments that satisfy an http rule for this method + sample_request = {'tuned_model': {'name': 'tunedModels/sample1'}} + + # get truthy value for each flattened field + mock_args = dict( + tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), + update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = gag_tuned_model.TunedModel.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.update_tuned_model(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{tuned_model.name=tunedModels/*}" % client.transport._host, args[1]) + + +def test_update_tuned_model_rest_flattened_error(transport: str = 'rest'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_tuned_model( + model_service.UpdateTunedModelRequest(), + tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), + update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + ) + + +def test_update_tuned_model_rest_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + model_service.DeleteTunedModelRequest, + dict, +]) +def test_delete_tuned_model_rest(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'name': 'tunedModels/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. 
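+ # DeleteTunedModel maps to an empty (google.protobuf.Empty) response,
+ # which the client surfaces as None, so the mocked body is an empty string.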
+ return_value = None + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = '' + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.delete_tuned_model(request) + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_tuned_model_rest_required_fields(request_type=model_service.DeleteTunedModelRequest): + transport_class = transports.ModelServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).delete_tuned_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = 'name_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).delete_tuned_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == 'name_value' + + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
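+ # Unlike the create/update variants above, the transcode result below
+ # carries no 'body' key, matching the payload-less HTTP DELETE verb.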
+ pb_request = request_type.pb(request)
+ transcode_result = {
+ 'uri': 'v1/sample_method',
+ 'method': "delete",
+ 'query_params': pb_request,
+ }
+ transcode.return_value = transcode_result
+
+ response_value = Response()
+ response_value.status_code = 200
+ json_return_value = ''
+
+ response_value._content = json_return_value.encode('UTF-8')
+ req.return_value = response_value
+
+ response = client.delete_tuned_model(request)
+
+ expected_params = [
+ ('$alt', 'json;enum-encoding=int')
+ ]
+ actual_params = req.call_args.kwargs['params']
+ assert expected_params == actual_params
+
+
+def test_delete_tuned_model_rest_unset_required_fields():
+ transport = transports.ModelServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+ unset_fields = transport.delete_tuned_model._get_unset_required_fields({})
+ assert set(unset_fields) == (set(()) & set(("name", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_delete_tuned_model_rest_interceptors(null_interceptor):
+ transport = transports.ModelServiceRestTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(),
+ )
+ client = ModelServiceClient(transport=transport)
+ with mock.patch.object(type(client.transport._session), "request") as req, \
+ mock.patch.object(path_template, "transcode") as transcode, \
+ mock.patch.object(transports.ModelServiceRestInterceptor, "pre_delete_tuned_model") as pre:
+ pre.assert_not_called()
+ pb_message = model_service.DeleteTunedModelRequest.pb(model_service.DeleteTunedModelRequest())
+ transcode.return_value = {
+ "method": "post",
+ "uri": "my_uri",
+ "body": pb_message,
+ "query_params": pb_message,
+ }
+
+ req.return_value = Response()
+ req.return_value.status_code = 200
+ req.return_value.request = PreparedRequest()
+
+ request = model_service.DeleteTunedModelRequest()
+ metadata = [
+ ("key", "val"),
+ ("cephalopod", "squid"),
+ ]
+ pre.return_value = request, metadata
+
+ client.delete_tuned_model(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+ pre.assert_called_once()
+
+
+def test_delete_tuned_model_rest_bad_request(transport: str = 'rest', request_type=model_service.DeleteTunedModelRequest):
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # send a request that will satisfy transcoding
+ request_init = {'name': 'tunedModels/sample1'}
+ request = request_type(**request_init)
+
+ # Mock the http request call within the method and fake a BadRequest error.
+ with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
+ # Wrap the value into a proper Response obj
+ response_value = Response()
+ response_value.status_code = 400
+ response_value.request = Request()
+ req.return_value = response_value
+ client.delete_tuned_model(request)
+
+
+def test_delete_tuned_model_rest_flattened():
+ client = ModelServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="rest",
+ )
+
+ # Mock the http request call within the method and fake a response.
+ with mock.patch.object(type(client.transport._session), 'request') as req:
+ # Designate an appropriate value for the returned response.
+ return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = {'name': 'tunedModels/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + name='name_value', + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = '' + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.delete_tuned_model(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{name=tunedModels/*}" % client.transport._host, args[1]) + + +def test_delete_tuned_model_rest_flattened_error(transport: str = 'rest'): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_tuned_model( + model_service.DeleteTunedModelRequest(), + name='name_value', + ) + + +def test_delete_tuned_model_rest_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ModelServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ModelServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ModelServiceClient( + client_options=options, + credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ModelServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = ModelServiceClient(transport=transport) + assert client.transport is transport + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. 
+ transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.ModelServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + +@pytest.mark.parametrize("transport_class", [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceGrpcAsyncIOTransport, + transports.ModelServiceRestTransport, +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, 'default') as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + +@pytest.mark.parametrize("transport_name", [ + "grpc", + "rest", +]) +def test_transport_kind(transport_name): + transport = ModelServiceClient.get_transport_class(transport_name)( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert transport.kind == transport_name + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.ModelServiceGrpcTransport, + ) + +def test_model_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.ModelServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_model_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.ai.generativelanguage_v1beta3.services.model_service.transports.ModelServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.ModelServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
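+ # The base transport only declares the interface; the concrete gRPC and
+ # REST transports override each of these methods with working stubs.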
+ methods = ( + 'get_model', + 'list_models', + 'get_tuned_model', + 'list_tuned_models', + 'create_tuned_model', + 'update_tuned_model', + 'delete_tuned_model', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + # Catch all for all remaining methods and properties + remainder = [ + 'kind', + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_model_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta3.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.ModelServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", + scopes=None, + default_scopes=( +), + quota_project_id="octopus", + ) + + +def test_model_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta3.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.ModelServiceTransport() + adc.assert_called_once() + + +def test_model_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, 'default', autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + ModelServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=( +), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceGrpcAsyncIOTransport, + ], +) +def test_model_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
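+ # google.auth.default() is patched so the test never depends on real
+ # Application Default Credentials configured on the machine running it.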
+ with mock.patch.object(google.auth, 'default', autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=(), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceGrpcAsyncIOTransport, + transports.ModelServiceRestTransport, + ], +) +def test_model_service_transport_auth_gdch_credentials(transport_class): + host = 'https://language.com' + api_audience_tests = [None, 'https://language2.com'] + api_audience_expect = [host, 'https://language2.com'] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, 'default', autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with( + e + ) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.ModelServiceGrpcTransport, grpc_helpers), + (transports.ModelServiceGrpcAsyncIOTransport, grpc_helpers_async) + ], +) +def test_model_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class( + quota_project_id="octopus", + scopes=["1", "2"] + ) + + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=( +), + scopes=["1", "2"], + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) +def test_model_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. 
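+ # The callback returns (cert bytes, key bytes); those values should be
+ # handed straight to grpc.ssl_channel_credentials, as asserted below.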
+ with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, + private_key=expected_key + ) + +def test_model_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel: + transports.ModelServiceRestTransport ( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +def test_model_service_rest_lro_client(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.AbstractOperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +@pytest.mark.parametrize("transport_name", [ + "grpc", + "grpc_asyncio", + "rest", +]) +def test_model_service_host_no_port(transport_name): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'), + transport=transport_name, + ) + assert client.transport._host == ( + 'generativelanguage.googleapis.com:443' + if transport_name in ['grpc', 'grpc_asyncio'] + else 'https://generativelanguage.googleapis.com' + ) + +@pytest.mark.parametrize("transport_name", [ + "grpc", + "grpc_asyncio", + "rest", +]) +def test_model_service_host_with_port(transport_name): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'), + transport=transport_name, + ) + assert client.transport._host == ( + 'generativelanguage.googleapis.com:8000' + if transport_name in ['grpc', 'grpc_asyncio'] + else 'https://generativelanguage.googleapis.com:8000' + ) + +@pytest.mark.parametrize("transport_name", [ + "rest", +]) +def test_model_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = ModelServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = ModelServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.get_model._session + session2 = client2.transport.get_model._session + assert session1 != session2 + session1 = client1.transport.list_models._session + session2 = client2.transport.list_models._session + assert session1 != session2 + session1 = client1.transport.get_tuned_model._session + session2 = client2.transport.get_tuned_model._session + assert session1 != session2 + session1 = client1.transport.list_tuned_models._session + session2 = client2.transport.list_tuned_models._session + assert session1 != session2 + session1 = client1.transport.create_tuned_model._session + session2 = client2.transport.create_tuned_model._session + assert session1 != 
session2 + session1 = client1.transport.update_tuned_model._session + session2 = client2.transport.update_tuned_model._session + assert session1 != session2 + session1 = client1.transport.delete_tuned_model._session + session2 = client2.transport.delete_tuned_model._session + assert session1 != session2 +def test_model_service_grpc_transport_channel(): + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.ModelServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_model_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.ModelServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) +def test_model_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, 'default') as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. 
+@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) +def test_model_service_transport_channel_mtls_with_adc( + transport_class +): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_model_service_grpc_lro_client(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='grpc', + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_model_service_grpc_lro_async_client(): + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='grpc_asyncio', + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_model_path(): + model = "squid" + expected = "models/{model}".format(model=model, ) + actual = ModelServiceClient.model_path(model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "model": "clam", + } + path = ModelServiceClient.model_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_model_path(path) + assert expected == actual + +def test_tuned_model_path(): + tuned_model = "whelk" + expected = "tunedModels/{tuned_model}".format(tuned_model=tuned_model, ) + actual = ModelServiceClient.tuned_model_path(tuned_model) + assert expected == actual + + +def test_parse_tuned_model_path(): + expected = { + "tuned_model": "octopus", + } + path = ModelServiceClient.tuned_model_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ModelServiceClient.parse_tuned_model_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "oyster" + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = ModelServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "nudibranch", + } + path = ModelServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "cuttlefish" + expected = "folders/{folder}".format(folder=folder, ) + actual = ModelServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "mussel", + } + path = ModelServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "winkle" + expected = "organizations/{organization}".format(organization=organization, ) + actual = ModelServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nautilus", + } + path = ModelServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "scallop" + expected = "projects/{project}".format(project=project, ) + actual = ModelServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "abalone", + } + path = ModelServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "squid" + location = "clam" + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = ModelServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "whelk", + "location": "octopus", + } + path = ModelServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
+ actual = ModelServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = ModelServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "rest": "_session", + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport + ) + with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close: + with client: + close.assert_not_called() + close.assert_called_once() + +def test_client_ctx(): + transports = [ + 'rest', + 'grpc', + ] + for transport in transports: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport + ) + # Test client calls underlying transport. 
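+ # Using the client as a context manager should close the underlying
+ # transport on exit, and only on exit.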
+ with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + +@pytest.mark.parametrize("client_class,transport_class", [ + (ModelServiceClient, transports.ModelServiceGrpcTransport), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport), +]) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_permission_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_permission_service.py new file mode 100644 index 000000000000..aa9954df98a8 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_permission_service.py @@ -0,0 +1,4220 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +import grpc +from grpc.experimental import aio +from collections.abc import Iterable +from google.protobuf import json_format +import json +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule +from proto.marshal.rules import wrappers +from requests import Response +from requests import Request, PreparedRequest +from requests.sessions import Session +from google.protobuf import json_format + +from google.ai.generativelanguage_v1beta3.services.permission_service import PermissionServiceAsyncClient +from google.ai.generativelanguage_v1beta3.services.permission_service import PermissionServiceClient +from google.ai.generativelanguage_v1beta3.services.permission_service import pagers +from google.ai.generativelanguage_v1beta3.services.permission_service import transports +from google.ai.generativelanguage_v1beta3.types import permission +from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.ai.generativelanguage_v1beta3.types import permission_service +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import path_template +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 # type: ignore +import google.auth + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
+def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert PermissionServiceClient._get_default_mtls_endpoint(None) is None + assert PermissionServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert PermissionServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert PermissionServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert PermissionServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert PermissionServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class,transport_name", [ + (PermissionServiceClient, "grpc"), + (PermissionServiceAsyncClient, "grpc_asyncio"), + (PermissionServiceClient, "rest"), +]) +def test_permission_service_client_from_service_account_info(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + 'generativelanguage.googleapis.com:443' + if transport_name in ['grpc', 'grpc_asyncio'] + else + 'https://generativelanguage.googleapis.com' + ) + + +@pytest.mark.parametrize("transport_class,transport_name", [ + (transports.PermissionServiceGrpcTransport, "grpc"), + (transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.PermissionServiceRestTransport, "rest"), +]) +def test_permission_service_client_service_account_always_use_jwt(transport_class, transport_name): + with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize("client_class,transport_name", [ + (PermissionServiceClient, "grpc"), + (PermissionServiceAsyncClient, "grpc_asyncio"), + (PermissionServiceClient, "rest"), +]) +def test_permission_service_client_from_service_account_file(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name) + assert client.transport._credentials == creds 
+    assert isinstance(client, client_class)
+
+    assert client.transport._host == (
+        'generativelanguage.googleapis.com:443'
+        if transport_name in ['grpc', 'grpc_asyncio']
+        else
+        'https://generativelanguage.googleapis.com'
+    )
+
+
+def test_permission_service_client_get_transport_class():
+    transport = PermissionServiceClient.get_transport_class()
+    available_transports = [
+        transports.PermissionServiceGrpcTransport,
+        transports.PermissionServiceRestTransport,
+    ]
+    assert transport in available_transports
+
+    transport = PermissionServiceClient.get_transport_class("grpc")
+    assert transport == transports.PermissionServiceGrpcTransport
+
+
+@pytest.mark.parametrize("client_class,transport_class,transport_name", [
+    (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc"),
+    (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio"),
+    (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest"),
+])
+@mock.patch.object(PermissionServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PermissionServiceClient))
+@mock.patch.object(PermissionServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PermissionServiceAsyncClient))
+def test_permission_service_client_client_options(client_class, transport_class, transport_name):
+    # Check that if a transport instance is provided we do not create a new one.
+    with mock.patch.object(PermissionServiceClient, 'get_transport_class') as gtc:
+        transport = transport_class(
+            credentials=ga_credentials.AnonymousCredentials()
+        )
+        client = client_class(transport=transport)
+        gtc.assert_not_called()
+
+    # Check that if the transport is named as a string we create a new one.
+    with mock.patch.object(PermissionServiceClient, 'get_transport_class') as gtc:
+        client = client_class(transport=transport_name)
+        gtc.assert_called()
+
+    # Check the case api_endpoint is provided.
+    options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
+    with mock.patch.object(transport_class, '__init__') as patched:
+        patched.return_value = None
+        client = client_class(transport=transport_name, client_options=options)
+        patched.assert_called_once_with(
+            credentials=None,
+            credentials_file=None,
+            host="squid.clam.whelk",
+            scopes=None,
+            client_cert_source_for_mtls=None,
+            quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
+            always_use_jwt_access=True,
+            api_audience=None,
+        )
+
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
+    # "never".
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+        with mock.patch.object(transport_class, '__init__') as patched:
+            patched.return_value = None
+            client = client_class(transport=transport_name)
+            patched.assert_called_once_with(
+                credentials=None,
+                credentials_file=None,
+                host=client.DEFAULT_ENDPOINT,
+                scopes=None,
+                client_cert_source_for_mtls=None,
+                quota_project_id=None,
+                client_info=transports.base.DEFAULT_CLIENT_INFO,
+                always_use_jwt_access=True,
+                api_audience=None,
+            )
+
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
+    # "always".
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
+        with mock.patch.object(transport_class, '__init__') as patched:
+            patched.return_value = None
+            client = client_class(transport=transport_name)
+            patched.assert_called_once_with(
+                credentials=None,
+                credentials_file=None,
+                host=client.DEFAULT_MTLS_ENDPOINT,
+                scopes=None,
+                client_cert_source_for_mtls=None,
+                quota_project_id=None,
+                client_info=transports.base.DEFAULT_CLIENT_INFO,
+                always_use_jwt_access=True,
+                api_audience=None,
+            )
+
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
+    # unsupported value.
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+        with pytest.raises(MutualTLSChannelError):
+            client = client_class(transport=transport_name)
+
+    # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}):
+        with pytest.raises(ValueError):
+            client = client_class(transport=transport_name)
+
+    # Check the case quota_project_id is provided
+    options = client_options.ClientOptions(quota_project_id="octopus")
+    with mock.patch.object(transport_class, '__init__') as patched:
+        patched.return_value = None
+        client = client_class(client_options=options, transport=transport_name)
+        patched.assert_called_once_with(
+            credentials=None,
+            credentials_file=None,
+            host=client.DEFAULT_ENDPOINT,
+            scopes=None,
+            client_cert_source_for_mtls=None,
+            quota_project_id="octopus",
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
+            always_use_jwt_access=True,
+            api_audience=None,
+        )
+    # Check the case api_audience is provided
+    options = client_options.ClientOptions(api_audience="https://language.googleapis.com")
+    with mock.patch.object(transport_class, '__init__') as patched:
+        patched.return_value = None
+        client = client_class(client_options=options, transport=transport_name)
+        patched.assert_called_once_with(
+            credentials=None,
+            credentials_file=None,
+            host=client.DEFAULT_ENDPOINT,
+            scopes=None,
+            client_cert_source_for_mtls=None,
+            quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
+            always_use_jwt_access=True,
+            api_audience="https://language.googleapis.com"
+        )
+
+@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [
+    (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc", "true"),
+    (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"),
+    (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc", "false"),
+    (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"),
+    (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest", "true"),
+    (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest", "false"),
+])
+@mock.patch.object(PermissionServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PermissionServiceClient))
+@mock.patch.object(PermissionServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PermissionServiceAsyncClient))
+@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"})
+def test_permission_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env):
+    # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default
+    # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists.
+
+    # Check the case client_cert_source is provided. Whether client cert is used depends on
+    # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
+        options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
+        with mock.patch.object(transport_class, '__init__') as patched:
+            patched.return_value = None
+            client = client_class(client_options=options, transport=transport_name)
+
+            if use_client_cert_env == "false":
+                expected_client_cert_source = None
+                expected_host = client.DEFAULT_ENDPOINT
+            else:
+                expected_client_cert_source = client_cert_source_callback
+                expected_host = client.DEFAULT_MTLS_ENDPOINT
+
+            patched.assert_called_once_with(
+                credentials=None,
+                credentials_file=None,
+                host=expected_host,
+                scopes=None,
+                client_cert_source_for_mtls=expected_client_cert_source,
+                quota_project_id=None,
+                client_info=transports.base.DEFAULT_CLIENT_INFO,
+                always_use_jwt_access=True,
+                api_audience=None,
+            )
+
+    # Check the case ADC client cert is provided. Whether client cert is used depends on
+    # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
+        with mock.patch.object(transport_class, '__init__') as patched:
+            with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True):
+                with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback):
+                    if use_client_cert_env == "false":
+                        expected_host = client.DEFAULT_ENDPOINT
+                        expected_client_cert_source = None
+                    else:
+                        expected_host = client.DEFAULT_MTLS_ENDPOINT
+                        expected_client_cert_source = client_cert_source_callback
+
+                    patched.return_value = None
+                    client = client_class(transport=transport_name)
+                    patched.assert_called_once_with(
+                        credentials=None,
+                        credentials_file=None,
+                        host=expected_host,
+                        scopes=None,
+                        client_cert_source_for_mtls=expected_client_cert_source,
+                        quota_project_id=None,
+                        client_info=transports.base.DEFAULT_CLIENT_INFO,
+                        always_use_jwt_access=True,
+                        api_audience=None,
+                    )
+
+    # Check the case client_cert_source and ADC client cert are not provided.
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}):
+        with mock.patch.object(transport_class, '__init__') as patched:
+            with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False):
+                patched.return_value = None
+                client = client_class(transport=transport_name)
+                patched.assert_called_once_with(
+                    credentials=None,
+                    credentials_file=None,
+                    host=client.DEFAULT_ENDPOINT,
+                    scopes=None,
+                    client_cert_source_for_mtls=None,
+                    quota_project_id=None,
+                    client_info=transports.base.DEFAULT_CLIENT_INFO,
+                    always_use_jwt_access=True,
+                    api_audience=None,
+                )
+
+
+@pytest.mark.parametrize("client_class", [
+    PermissionServiceClient, PermissionServiceAsyncClient
+])
+@mock.patch.object(PermissionServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PermissionServiceClient))
+@mock.patch.object(PermissionServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PermissionServiceAsyncClient))
+def test_permission_service_client_get_mtls_endpoint_and_cert_source(client_class):
+    mock_client_cert_source = mock.Mock()
+
+    # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true".
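+    # An api_endpoint set explicitly in client_options should be returned
+    # unchanged; it takes precedence over mTLS endpoint switching.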
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc"), + (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest"), +]) +def test_permission_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. 
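+    # User-supplied scopes should be forwarded to the transport verbatim;
+    # the transport falls back to its default scopes only when this is None.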
+ options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc", grpc_helpers), + (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), + (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest", None), +]) +def test_permission_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +def test_permission_service_client_client_options_from_dict(): + with mock.patch('google.ai.generativelanguage_v1beta3.services.permission_service.transports.PermissionServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = PermissionServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc", grpc_helpers), + (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), +]) +def test_permission_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. 
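+    # grpc_helpers.create_channel should receive the credentials loaded from
+    # the file (file_creds below) rather than the application default ones.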
+    with mock.patch.object(
+        google.auth, "load_credentials_from_file", autospec=True
+    ) as load_creds, mock.patch.object(
+        google.auth, "default", autospec=True
+    ) as adc, mock.patch.object(
+        grpc_helpers, "create_channel"
+    ) as create_channel:
+        creds = ga_credentials.AnonymousCredentials()
+        file_creds = ga_credentials.AnonymousCredentials()
+        load_creds.return_value = (file_creds, None)
+        adc.return_value = (creds, None)
+        client = client_class(client_options=options, transport=transport_name)
+        create_channel.assert_called_with(
+            "generativelanguage.googleapis.com:443",
+            credentials=file_creds,
+            credentials_file=None,
+            quota_project_id=None,
+            default_scopes=(),
+            scopes=None,
+            default_host="generativelanguage.googleapis.com",
+            ssl_credentials=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+
+@pytest.mark.parametrize("request_type", [
+    permission_service.CreatePermissionRequest,
+    dict,
+])
+def test_create_permission(request_type, transport: str = 'grpc'):
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_permission),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = gag_permission.Permission(
+            name='name_value',
+            grantee_type=gag_permission.Permission.GranteeType.USER,
+            email_address='email_address_value',
+            role=gag_permission.Permission.Role.OWNER,
+        )
+        response = client.create_permission(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == permission_service.CreatePermissionRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, gag_permission.Permission)
+    assert response.name == 'name_value'
+    assert response.grantee_type == gag_permission.Permission.GranteeType.USER
+    assert response.email_address == 'email_address_value'
+    assert response.role == gag_permission.Permission.Role.OWNER
+
+
+def test_create_permission_empty_call():
+    # This test is a coverage failsafe to make sure that totally empty calls,
+    # i.e. request == None and no flattened fields passed, work.
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_permission),
+            '__call__') as call:
+        client.create_permission()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == permission_service.CreatePermissionRequest()
+
+@pytest.mark.asyncio
+async def test_create_permission_async(transport: str = 'grpc_asyncio', request_type=permission_service.CreatePermissionRequest):
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
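+    # grpc_helpers_async.FakeUnaryUnaryCall wraps the designated response in an
+    # awaitable fake call object, mirroring what a real async gRPC stub returns.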
+    with mock.patch.object(
+            type(client.transport.create_permission),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_permission.Permission(
+            name='name_value',
+            grantee_type=gag_permission.Permission.GranteeType.USER,
+            email_address='email_address_value',
+            role=gag_permission.Permission.Role.OWNER,
+        ))
+        response = await client.create_permission(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == permission_service.CreatePermissionRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, gag_permission.Permission)
+    assert response.name == 'name_value'
+    assert response.grantee_type == gag_permission.Permission.GranteeType.USER
+    assert response.email_address == 'email_address_value'
+    assert response.role == gag_permission.Permission.Role.OWNER
+
+
+@pytest.mark.asyncio
+async def test_create_permission_async_from_dict():
+    await test_create_permission_async(request_type=dict)
+
+
+def test_create_permission_field_headers():
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = permission_service.CreatePermissionRequest()
+
+    request.parent = 'parent_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_permission),
+            '__call__') as call:
+        call.return_value = gag_permission.Permission()
+        client.create_permission(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'parent=parent_value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_create_permission_field_headers_async():
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = permission_service.CreatePermissionRequest()
+
+    request.parent = 'parent_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_permission),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_permission.Permission())
+        await client.create_permission(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'parent=parent_value',
+    ) in kw['metadata']
+
+
+def test_create_permission_flattened():
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_permission),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = gag_permission.Permission()
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.create_permission(
+            parent='parent_value',
+            permission=gag_permission.Permission(name='name_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].parent
+        mock_val = 'parent_value'
+        assert arg == mock_val
+        arg = args[0].permission
+        mock_val = gag_permission.Permission(name='name_value')
+        assert arg == mock_val
+
+
+def test_create_permission_flattened_error():
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.create_permission(
+            permission_service.CreatePermissionRequest(),
+            parent='parent_value',
+            permission=gag_permission.Permission(name='name_value'),
+        )
+
+@pytest.mark.asyncio
+async def test_create_permission_flattened_async():
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.create_permission),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_permission.Permission())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.create_permission(
+            parent='parent_value',
+            permission=gag_permission.Permission(name='name_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].parent
+        mock_val = 'parent_value'
+        assert arg == mock_val
+        arg = args[0].permission
+        mock_val = gag_permission.Permission(name='name_value')
+        assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_create_permission_flattened_error_async():
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.create_permission(
+            permission_service.CreatePermissionRequest(),
+            parent='parent_value',
+            permission=gag_permission.Permission(name='name_value'),
+        )
+
+
+@pytest.mark.parametrize("request_type", [
+    permission_service.GetPermissionRequest,
+    dict,
+])
+def test_get_permission(request_type, transport: str = 'grpc'):
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_permission),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = permission.Permission(
+            name='name_value',
+            grantee_type=permission.Permission.GranteeType.USER,
+            email_address='email_address_value',
+            role=permission.Permission.Role.OWNER,
+        )
+        response = client.get_permission(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == permission_service.GetPermissionRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, permission.Permission)
+    assert response.name == 'name_value'
+    assert response.grantee_type == permission.Permission.GranteeType.USER
+    assert response.email_address == 'email_address_value'
+    assert response.role == permission.Permission.Role.OWNER
+
+
+def test_get_permission_empty_call():
+    # This test is a coverage failsafe to make sure that totally empty calls,
+    # i.e. request == None and no flattened fields passed, work.
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_permission),
+            '__call__') as call:
+        client.get_permission()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == permission_service.GetPermissionRequest()
+
+@pytest.mark.asyncio
+async def test_get_permission_async(transport: str = 'grpc_asyncio', request_type=permission_service.GetPermissionRequest):
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_permission),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(permission.Permission(
+            name='name_value',
+            grantee_type=permission.Permission.GranteeType.USER,
+            email_address='email_address_value',
+            role=permission.Permission.Role.OWNER,
+        ))
+        response = await client.get_permission(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == permission_service.GetPermissionRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, permission.Permission)
+    assert response.name == 'name_value'
+    assert response.grantee_type == permission.Permission.GranteeType.USER
+    assert response.email_address == 'email_address_value'
+    assert response.role == permission.Permission.Role.OWNER
+
+
+@pytest.mark.asyncio
+async def test_get_permission_async_from_dict():
+    await test_get_permission_async(request_type=dict)
+
+
+def test_get_permission_field_headers():
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = permission_service.GetPermissionRequest()
+
+    request.name = 'name_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_permission),
+            '__call__') as call:
+        call.return_value = permission.Permission()
+        client.get_permission(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'name=name_value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_get_permission_field_headers_async():
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = permission_service.GetPermissionRequest()
+
+    request.name = 'name_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_permission),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(permission.Permission())
+        await client.get_permission(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'name=name_value',
+    ) in kw['metadata']
+
+
+def test_get_permission_flattened():
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_permission),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = permission.Permission()
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.get_permission(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].name
+        mock_val = 'name_value'
+        assert arg == mock_val
+
+
+def test_get_permission_flattened_error():
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_permission(
+            permission_service.GetPermissionRequest(),
+            name='name_value',
+        )
+
+@pytest.mark.asyncio
+async def test_get_permission_flattened_async():
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.get_permission),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(permission.Permission())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_permission(
+            name='name_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].name
+        mock_val = 'name_value'
+        assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_get_permission_flattened_error_async():
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_permission(
+            permission_service.GetPermissionRequest(),
+            name='name_value',
+        )
+
+
+@pytest.mark.parametrize("request_type", [
+    permission_service.ListPermissionsRequest,
+    dict,
+])
+def test_list_permissions(request_type, transport: str = 'grpc'):
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_permissions),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = permission_service.ListPermissionsResponse(
+            next_page_token='next_page_token_value',
+        )
+        response = client.list_permissions(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == permission_service.ListPermissionsRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListPermissionsPager)
+    assert response.next_page_token == 'next_page_token_value'
+
+
+def test_list_permissions_empty_call():
+    # This test is a coverage failsafe to make sure that totally empty calls,
+    # i.e. request == None and no flattened fields passed, work.
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_permissions),
+            '__call__') as call:
+        client.list_permissions()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == permission_service.ListPermissionsRequest()
+
+@pytest.mark.asyncio
+async def test_list_permissions_async(transport: str = 'grpc_asyncio', request_type=permission_service.ListPermissionsRequest):
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_permissions),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(permission_service.ListPermissionsResponse(
+            next_page_token='next_page_token_value',
+        ))
+        response = await client.list_permissions(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == permission_service.ListPermissionsRequest()
+
+    # Establish that the response is the type that we expect.
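+    # list_permissions returns a pager that wraps the first response and
+    # lazily fetches any remaining pages as it is iterated.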
+ assert isinstance(response, pagers.ListPermissionsAsyncPager) + assert response.next_page_token == 'next_page_token_value' + + +@pytest.mark.asyncio +async def test_list_permissions_async_from_dict(): + await test_list_permissions_async(request_type=dict) + + +def test_list_permissions_field_headers(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.ListPermissionsRequest() + + request.parent = 'parent_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_permissions), + '__call__') as call: + call.return_value = permission_service.ListPermissionsResponse() + client.list_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent_value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_list_permissions_field_headers_async(): + client = PermissionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.ListPermissionsRequest() + + request.parent = 'parent_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_permissions), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(permission_service.ListPermissionsResponse()) + await client.list_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'parent=parent_value', + ) in kw['metadata'] + + +def test_list_permissions_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_permissions), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = permission_service.ListPermissionsResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_permissions( + parent='parent_value', + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = 'parent_value' + assert arg == mock_val + + +def test_list_permissions_flattened_error(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
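+    # The request-object and flattened-argument calling conventions are
+    # mutually exclusive, so mixing them should raise immediately.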
+    with pytest.raises(ValueError):
+        client.list_permissions(
+            permission_service.ListPermissionsRequest(),
+            parent='parent_value',
+        )
+
+@pytest.mark.asyncio
+async def test_list_permissions_flattened_async():
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_permissions),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(permission_service.ListPermissionsResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_permissions(
+            parent='parent_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].parent
+        mock_val = 'parent_value'
+        assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_list_permissions_flattened_error_async():
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_permissions(
+            permission_service.ListPermissionsRequest(),
+            parent='parent_value',
+        )
+
+
+def test_list_permissions_pager(transport_name: str = "grpc"):
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport_name,
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_permissions),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            permission_service.ListPermissionsResponse(
+                permissions=[
+                    permission.Permission(),
+                    permission.Permission(),
+                    permission.Permission(),
+                ],
+                next_page_token='abc',
+            ),
+            permission_service.ListPermissionsResponse(
+                permissions=[],
+                next_page_token='def',
+            ),
+            permission_service.ListPermissionsResponse(
+                permissions=[
+                    permission.Permission(),
+                ],
+                next_page_token='ghi',
+            ),
+            permission_service.ListPermissionsResponse(
+                permissions=[
+                    permission.Permission(),
+                    permission.Permission(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ('parent', ''),
+            )),
+        )
+        pager = client.list_permissions(request={})
+
+        assert pager._metadata == metadata
+
+        results = list(pager)
+        assert len(results) == 6
+        assert all(isinstance(i, permission.Permission)
+                   for i in results)
+
+
+def test_list_permissions_pages(transport_name: str = "grpc"):
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport_name,
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_permissions),
+            '__call__') as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            permission_service.ListPermissionsResponse(
+                permissions=[
+                    permission.Permission(),
+                    permission.Permission(),
+                    permission.Permission(),
+                ],
+                next_page_token='abc',
+            ),
+            permission_service.ListPermissionsResponse(
+                permissions=[],
+                next_page_token='def',
+            ),
+            permission_service.ListPermissionsResponse(
+                permissions=[
+                    permission.Permission(),
+                ],
+                next_page_token='ghi',
+            ),
+            permission_service.ListPermissionsResponse(
+                permissions=[
+                    permission.Permission(),
+                    permission.Permission(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_permissions(request={}).pages)
+        for page_, token in zip(pages, ['abc', 'def', 'ghi', '']):
+            assert page_.raw_page.next_page_token == token
+
+@pytest.mark.asyncio
+async def test_list_permissions_async_pager():
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_permissions),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            permission_service.ListPermissionsResponse(
+                permissions=[
+                    permission.Permission(),
+                    permission.Permission(),
+                    permission.Permission(),
+                ],
+                next_page_token='abc',
+            ),
+            permission_service.ListPermissionsResponse(
+                permissions=[],
+                next_page_token='def',
+            ),
+            permission_service.ListPermissionsResponse(
+                permissions=[
+                    permission.Permission(),
+                ],
+                next_page_token='ghi',
+            ),
+            permission_service.ListPermissionsResponse(
+                permissions=[
+                    permission.Permission(),
+                    permission.Permission(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_permissions(request={},)
+        assert async_pager.next_page_token == 'abc'
+        responses = []
+        async for response in async_pager:  # pragma: no branch
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, permission.Permission)
+                   for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_permissions_async_pages():
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.list_permissions),
+            '__call__', new_callable=mock.AsyncMock) as call:
+        # Set the response to a series of pages.
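+        # With side_effect, the mock returns one response per page fetch; the
+        # trailing RuntimeError guards against the pager over-fetching.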
+ call.side_effect = ( + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + permission.Permission(), + ], + next_page_token='abc', + ), + permission_service.ListPermissionsResponse( + permissions=[], + next_page_token='def', + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + ], + next_page_token='ghi', + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_permissions(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + +@pytest.mark.parametrize("request_type", [ + permission_service.UpdatePermissionRequest, + dict, +]) +def test_update_permission(request_type, transport: str = 'grpc'): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_permission), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = gag_permission.Permission( + name='name_value', + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address='email_address_value', + role=gag_permission.Permission.Role.OWNER, + ) + response = client.update_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == permission_service.UpdatePermissionRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_permission.Permission) + assert response.name == 'name_value' + assert response.grantee_type == gag_permission.Permission.GranteeType.USER + assert response.email_address == 'email_address_value' + assert response.role == gag_permission.Permission.Role.OWNER + + +def test_update_permission_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_permission), + '__call__') as call: + client.update_permission() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == permission_service.UpdatePermissionRequest() + +@pytest.mark.asyncio +async def test_update_permission_async(transport: str = 'grpc_asyncio', request_type=permission_service.UpdatePermissionRequest): + client = PermissionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_permission),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_permission.Permission(
+            name='name_value',
+            grantee_type=gag_permission.Permission.GranteeType.USER,
+            email_address='email_address_value',
+            role=gag_permission.Permission.Role.OWNER,
+        ))
+        response = await client.update_permission(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == permission_service.UpdatePermissionRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, gag_permission.Permission)
+    assert response.name == 'name_value'
+    assert response.grantee_type == gag_permission.Permission.GranteeType.USER
+    assert response.email_address == 'email_address_value'
+    assert response.role == gag_permission.Permission.Role.OWNER
+
+
+@pytest.mark.asyncio
+async def test_update_permission_async_from_dict():
+    await test_update_permission_async(request_type=dict)
+
+
+def test_update_permission_field_headers():
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = permission_service.UpdatePermissionRequest()
+
+    request.permission.name = 'name_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_permission),
+            '__call__') as call:
+        call.return_value = gag_permission.Permission()
+        client.update_permission(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'permission.name=name_value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_update_permission_field_headers_async():
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = permission_service.UpdatePermissionRequest()
+
+    request.permission.name = 'name_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_permission),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_permission.Permission())
+        await client.update_permission(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'permission.name=name_value',
+    ) in kw['metadata']
+
+
+def test_update_permission_flattened():
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_permission),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = gag_permission.Permission()
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.update_permission(
+            permission=gag_permission.Permission(name='name_value'),
+            update_mask=field_mask_pb2.FieldMask(paths=['paths_value']),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].permission
+        mock_val = gag_permission.Permission(name='name_value')
+        assert arg == mock_val
+        arg = args[0].update_mask
+        mock_val = field_mask_pb2.FieldMask(paths=['paths_value'])
+        assert arg == mock_val
+
+
+def test_update_permission_flattened_error():
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.update_permission(
+            permission_service.UpdatePermissionRequest(),
+            permission=gag_permission.Permission(name='name_value'),
+            update_mask=field_mask_pb2.FieldMask(paths=['paths_value']),
+        )
+
+@pytest.mark.asyncio
+async def test_update_permission_flattened_async():
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.update_permission),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_permission.Permission())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.update_permission(
+            permission=gag_permission.Permission(name='name_value'),
+            update_mask=field_mask_pb2.FieldMask(paths=['paths_value']),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].permission
+        mock_val = gag_permission.Permission(name='name_value')
+        assert arg == mock_val
+        arg = args[0].update_mask
+        mock_val = field_mask_pb2.FieldMask(paths=['paths_value'])
+        assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_update_permission_flattened_error_async():
+    client = PermissionServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.update_permission(
+            permission_service.UpdatePermissionRequest(),
+            permission=gag_permission.Permission(name='name_value'),
+            update_mask=field_mask_pb2.FieldMask(paths=['paths_value']),
+        )
+
+
+@pytest.mark.parametrize("request_type", [
+    permission_service.DeletePermissionRequest,
+    dict,
+])
+def test_delete_permission(request_type, transport: str = 'grpc'):
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+ request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.delete_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == permission_service.DeletePermissionRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_permission_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='grpc', + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), + '__call__') as call: + client.delete_permission() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == permission_service.DeletePermissionRequest() + +@pytest.mark.asyncio +async def test_delete_permission_async(transport: str = 'grpc_asyncio', request_type=permission_service.DeletePermissionRequest): + client = PermissionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == permission_service.DeletePermissionRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_permission_async_from_dict(): + await test_delete_permission_async(request_type=dict) + + +def test_delete_permission_field_headers(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.DeletePermissionRequest() + + request.name = 'name_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), + '__call__') as call: + call.return_value = None + client.delete_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
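+    # The URI path field is propagated as an `x-goog-request-params` metadata
+    # entry (built via gapic_v1.routing_header) so the backend can route it.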
+ _, _, kw = call.mock_calls[0]
+ assert (
+ 'x-goog-request-params',
+ 'name=name_value',
+ ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_delete_permission_field_headers_async():
+ client = PermissionServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = permission_service.DeletePermissionRequest()
+
+ request.name = 'name_value'
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.delete_permission),
+ '__call__') as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+ await client.delete_permission(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert (
+ 'x-goog-request-params',
+ 'name=name_value',
+ ) in kw['metadata']
+
+
+def test_delete_permission_flattened():
+ client = PermissionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.delete_permission),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = None
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.delete_permission(
+ name='name_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].name
+ mock_val = 'name_value'
+ assert arg == mock_val
+
+
+def test_delete_permission_flattened_error():
+ client = PermissionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.delete_permission(
+ permission_service.DeletePermissionRequest(),
+ name='name_value',
+ )
+
+@pytest.mark.asyncio
+async def test_delete_permission_flattened_async():
+ client = PermissionServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.delete_permission),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.delete_permission(
+ name='name_value',
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ arg = args[0].name
+ mock_val = 'name_value'
+ assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_delete_permission_flattened_error_async():
+ client = PermissionServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
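+ # (The flattened kwargs are just sugar for building the request message;
+ # passing both a ready-made request and kwargs would be ambiguous, so
+ # the client raises ValueError rather than guessing which wins.)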
+ with pytest.raises(ValueError):
+ await client.delete_permission(
+ permission_service.DeletePermissionRequest(),
+ name='name_value',
+ )
+
+
+@pytest.mark.parametrize("request_type", [
+ permission_service.TransferOwnershipRequest,
+ dict,
+])
+def test_transfer_ownership(request_type, transport: str = 'grpc'):
+ client = PermissionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.transfer_ownership),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = permission_service.TransferOwnershipResponse(
+ )
+ response = client.transfer_ownership(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == permission_service.TransferOwnershipRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, permission_service.TransferOwnershipResponse)
+
+
+def test_transfer_ownership_empty_call():
+ # This test is a coverage failsafe to make sure that totally empty calls,
+ # i.e. request == None and no flattened fields passed, work.
+ client = PermissionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport='grpc',
+ )
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.transfer_ownership),
+ '__call__') as call:
+ client.transfer_ownership()
+ call.assert_called()
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == permission_service.TransferOwnershipRequest()
+
+@pytest.mark.asyncio
+async def test_transfer_ownership_async(transport: str = 'grpc_asyncio', request_type=permission_service.TransferOwnershipRequest):
+ client = PermissionServiceAsyncClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.transfer_ownership),
+ '__call__') as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(permission_service.TransferOwnershipResponse(
+ ))
+ response = await client.transfer_ownership(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == permission_service.TransferOwnershipRequest()
+
+ # Establish that the response is the type that we expect.
+ assert isinstance(response, permission_service.TransferOwnershipResponse)
+
+
+@pytest.mark.asyncio
+async def test_transfer_ownership_async_from_dict():
+ await test_transfer_ownership_async(request_type=dict)
+
+
+def test_transfer_ownership_field_headers():
+ client = PermissionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ )
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = permission_service.TransferOwnershipRequest() + + request.name = 'name_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.transfer_ownership), + '__call__') as call: + call.return_value = permission_service.TransferOwnershipResponse() + client.transfer_ownership(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name_value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_transfer_ownership_field_headers_async(): + client = PermissionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.TransferOwnershipRequest() + + request.name = 'name_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.transfer_ownership), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(permission_service.TransferOwnershipResponse()) + await client.transfer_ownership(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'name=name_value', + ) in kw['metadata'] + + +@pytest.mark.parametrize("request_type", [ + permission_service.CreatePermissionRequest, + dict, +]) +def test_create_permission_rest(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'parent': 'tunedModels/sample1'} + request_init["permission"] = {'name': 'name_value', 'grantee_type': 1, 'email_address': 'email_address_value', 'role': 1} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = gag_permission.Permission( + name='name_value', + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address='email_address_value', + role=gag_permission.Permission.Role.OWNER, + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = gag_permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.create_permission(request) + + # Establish that the response is the type that we expect. 
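+ # (Note the request dict above supplied integer enum values
+ # ('grantee_type': 1, 'role': 1); after the JSON round trip the parsed
+ # response surfaces them as the typed enum members asserted below.)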
+ assert isinstance(response, gag_permission.Permission) + assert response.name == 'name_value' + assert response.grantee_type == gag_permission.Permission.GranteeType.USER + assert response.email_address == 'email_address_value' + assert response.role == gag_permission.Permission.Role.OWNER + + +def test_create_permission_rest_required_fields(request_type=permission_service.CreatePermissionRequest): + transport_class = transports.PermissionServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).create_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = 'parent_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).create_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == 'parent_value' + + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = gag_permission.Permission() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
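+ # (Roughly what the real transcode() would emit for this rpc, assuming
+ # the v1beta3 http rule -- a POST to
+ # '/v1beta3/{parent=tunedModels/*}/permissions' with the permission as
+ # the body; the placeholder uri below keeps this test independent of
+ # the actual path template.)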
+ pb_request = request_type.pb(request)
+ transcode_result = {
+ 'uri': 'v1/sample_method',
+ 'method': "post",
+ 'query_params': pb_request,
+ }
+ transcode_result['body'] = pb_request
+ transcode.return_value = transcode_result
+
+ response_value = Response()
+ response_value.status_code = 200
+
+ pb_return_value = gag_permission.Permission.pb(return_value)
+ json_return_value = json_format.MessageToJson(pb_return_value)
+
+ response_value._content = json_return_value.encode('UTF-8')
+ req.return_value = response_value
+
+ response = client.create_permission(request)
+
+ expected_params = [
+ ('$alt', 'json;enum-encoding=int')
+ ]
+ actual_params = req.call_args.kwargs['params']
+ assert expected_params == actual_params
+
+
+def test_create_permission_rest_unset_required_fields():
+ transport = transports.PermissionServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+ unset_fields = transport.create_permission._get_unset_required_fields({})
+ assert set(unset_fields) == (set(()) & set(("parent", "permission", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_create_permission_rest_interceptors(null_interceptor):
+ transport = transports.PermissionServiceRestTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ interceptor=None if null_interceptor else transports.PermissionServiceRestInterceptor(),
+ )
+ client = PermissionServiceClient(transport=transport)
+ with mock.patch.object(type(client.transport._session), "request") as req, \
+ mock.patch.object(path_template, "transcode") as transcode, \
+ mock.patch.object(transports.PermissionServiceRestInterceptor, "post_create_permission") as post, \
+ mock.patch.object(transports.PermissionServiceRestInterceptor, "pre_create_permission") as pre:
+ pre.assert_not_called()
+ post.assert_not_called()
+ pb_message = permission_service.CreatePermissionRequest.pb(permission_service.CreatePermissionRequest())
+ transcode.return_value = {
+ "method": "post",
+ "uri": "my_uri",
+ "body": pb_message,
+ "query_params": pb_message,
+ }
+
+ req.return_value = Response()
+ req.return_value.status_code = 200
+ req.return_value.request = PreparedRequest()
+ req.return_value._content = gag_permission.Permission.to_json(gag_permission.Permission())
+
+ request = permission_service.CreatePermissionRequest()
+ metadata = [
+ ("key", "val"),
+ ("cephalopod", "squid"),
+ ]
+ pre.return_value = request, metadata
+ post.return_value = gag_permission.Permission()
+
+ client.create_permission(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+ pre.assert_called_once()
+ post.assert_called_once()
+
+
+def test_create_permission_rest_bad_request(transport: str = 'rest', request_type=permission_service.CreatePermissionRequest):
+ client = PermissionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # send a request that will satisfy transcoding
+ request_init = {'parent': 'tunedModels/sample1'}
+ request_init["permission"] = {'name': 'name_value', 'grantee_type': 1, 'email_address': 'email_address_value', 'role': 1}
+ request = request_type(**request_init)
+
+ # Mock the http request call within the method and fake a BadRequest error.
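+ # (google.api_core translates HTTP status codes into typed exceptions --
+ # 400 becomes core_exceptions.BadRequest -- so returning a bare 400
+ # Response is enough to exercise the error path.)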
+ with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.create_permission(request) + + +def test_create_permission_rest_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = gag_permission.Permission() + + # get arguments that satisfy an http rule for this method + sample_request = {'parent': 'tunedModels/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + parent='parent_value', + permission=gag_permission.Permission(name='name_value'), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = gag_permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.create_permission(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{parent=tunedModels/*}/permissions" % client.transport._host, args[1]) + + +def test_create_permission_rest_flattened_error(transport: str = 'rest'): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_permission( + permission_service.CreatePermissionRequest(), + parent='parent_value', + permission=gag_permission.Permission(name='name_value'), + ) + + +def test_create_permission_rest_error(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + permission_service.GetPermissionRequest, + dict, +]) +def test_get_permission_rest(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'name': 'tunedModels/sample1/permissions/sample2'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. 
+ return_value = permission.Permission( + name='name_value', + grantee_type=permission.Permission.GranteeType.USER, + email_address='email_address_value', + role=permission.Permission.Role.OWNER, + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.get_permission(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, permission.Permission) + assert response.name == 'name_value' + assert response.grantee_type == permission.Permission.GranteeType.USER + assert response.email_address == 'email_address_value' + assert response.role == permission.Permission.Role.OWNER + + +def test_get_permission_rest_required_fields(request_type=permission_service.GetPermissionRequest): + transport_class = transports.PermissionServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = 'name_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == 'name_value' + + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = permission.Permission() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request)
+ transcode_result = {
+ 'uri': 'v1/sample_method',
+ 'method': "get",
+ 'query_params': pb_request,
+ }
+ transcode.return_value = transcode_result
+
+ response_value = Response()
+ response_value.status_code = 200
+
+ pb_return_value = permission.Permission.pb(return_value)
+ json_return_value = json_format.MessageToJson(pb_return_value)
+
+ response_value._content = json_return_value.encode('UTF-8')
+ req.return_value = response_value
+
+ response = client.get_permission(request)
+
+ expected_params = [
+ ('$alt', 'json;enum-encoding=int')
+ ]
+ actual_params = req.call_args.kwargs['params']
+ assert expected_params == actual_params
+
+
+def test_get_permission_rest_unset_required_fields():
+ transport = transports.PermissionServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+ unset_fields = transport.get_permission._get_unset_required_fields({})
+ assert set(unset_fields) == (set(()) & set(("name", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_get_permission_rest_interceptors(null_interceptor):
+ transport = transports.PermissionServiceRestTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ interceptor=None if null_interceptor else transports.PermissionServiceRestInterceptor(),
+ )
+ client = PermissionServiceClient(transport=transport)
+ with mock.patch.object(type(client.transport._session), "request") as req, \
+ mock.patch.object(path_template, "transcode") as transcode, \
+ mock.patch.object(transports.PermissionServiceRestInterceptor, "post_get_permission") as post, \
+ mock.patch.object(transports.PermissionServiceRestInterceptor, "pre_get_permission") as pre:
+ pre.assert_not_called()
+ post.assert_not_called()
+ pb_message = permission_service.GetPermissionRequest.pb(permission_service.GetPermissionRequest())
+ transcode.return_value = {
+ "method": "post",
+ "uri": "my_uri",
+ "body": pb_message,
+ "query_params": pb_message,
+ }
+
+ req.return_value = Response()
+ req.return_value.status_code = 200
+ req.return_value.request = PreparedRequest()
+ req.return_value._content = permission.Permission.to_json(permission.Permission())
+
+ request = permission_service.GetPermissionRequest()
+ metadata = [
+ ("key", "val"),
+ ("cephalopod", "squid"),
+ ]
+ pre.return_value = request, metadata
+ post.return_value = permission.Permission()
+
+ client.get_permission(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+ pre.assert_called_once()
+ post.assert_called_once()
+
+
+def test_get_permission_rest_bad_request(transport: str = 'rest', request_type=permission_service.GetPermissionRequest):
+ client = PermissionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # send a request that will satisfy transcoding
+ request_init = {'name': 'tunedModels/sample1/permissions/sample2'}
+ request = request_type(**request_init)
+
+ # Mock the http request call within the method and fake a BadRequest error.
+ with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_permission(request) + + +def test_get_permission_rest_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = permission.Permission() + + # get arguments that satisfy an http rule for this method + sample_request = {'name': 'tunedModels/sample1/permissions/sample2'} + + # get truthy value for each flattened field + mock_args = dict( + name='name_value', + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.get_permission(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{name=tunedModels/*/permissions/*}" % client.transport._host, args[1]) + + +def test_get_permission_rest_flattened_error(transport: str = 'rest'): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_permission( + permission_service.GetPermissionRequest(), + name='name_value', + ) + + +def test_get_permission_rest_error(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + permission_service.ListPermissionsRequest, + dict, +]) +def test_list_permissions_rest(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'parent': 'tunedModels/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = permission_service.ListPermissionsResponse( + next_page_token='next_page_token_value', + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = permission_service.ListPermissionsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.list_permissions(request) + + # Establish that the response is the type that we expect. 
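+ # list_permissions wraps the raw ListPermissionsResponse in a pager
+ # that re-issues the request with next_page_token to fetch further
+ # pages on iteration; only the first page is faked here.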
+ assert isinstance(response, pagers.ListPermissionsPager) + assert response.next_page_token == 'next_page_token_value' + + +def test_list_permissions_rest_required_fields(request_type=permission_service.ListPermissionsRequest): + transport_class = transports.PermissionServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).list_permissions._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = 'parent_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).list_permissions._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("page_size", "page_token", )) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == 'parent_value' + + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = permission_service.ListPermissionsResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request)
+ transcode_result = {
+ 'uri': 'v1/sample_method',
+ 'method': "get",
+ 'query_params': pb_request,
+ }
+ transcode.return_value = transcode_result
+
+ response_value = Response()
+ response_value.status_code = 200
+
+ pb_return_value = permission_service.ListPermissionsResponse.pb(return_value)
+ json_return_value = json_format.MessageToJson(pb_return_value)
+
+ response_value._content = json_return_value.encode('UTF-8')
+ req.return_value = response_value
+
+ response = client.list_permissions(request)
+
+ expected_params = [
+ ('$alt', 'json;enum-encoding=int')
+ ]
+ actual_params = req.call_args.kwargs['params']
+ assert expected_params == actual_params
+
+
+def test_list_permissions_rest_unset_required_fields():
+ transport = transports.PermissionServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+ unset_fields = transport.list_permissions._get_unset_required_fields({})
+ assert set(unset_fields) == (set(("pageSize", "pageToken", )) & set(("parent", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_list_permissions_rest_interceptors(null_interceptor):
+ transport = transports.PermissionServiceRestTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ interceptor=None if null_interceptor else transports.PermissionServiceRestInterceptor(),
+ )
+ client = PermissionServiceClient(transport=transport)
+ with mock.patch.object(type(client.transport._session), "request") as req, \
+ mock.patch.object(path_template, "transcode") as transcode, \
+ mock.patch.object(transports.PermissionServiceRestInterceptor, "post_list_permissions") as post, \
+ mock.patch.object(transports.PermissionServiceRestInterceptor, "pre_list_permissions") as pre:
+ pre.assert_not_called()
+ post.assert_not_called()
+ pb_message = permission_service.ListPermissionsRequest.pb(permission_service.ListPermissionsRequest())
+ transcode.return_value = {
+ "method": "post",
+ "uri": "my_uri",
+ "body": pb_message,
+ "query_params": pb_message,
+ }
+
+ req.return_value = Response()
+ req.return_value.status_code = 200
+ req.return_value.request = PreparedRequest()
+ req.return_value._content = permission_service.ListPermissionsResponse.to_json(permission_service.ListPermissionsResponse())
+
+ request = permission_service.ListPermissionsRequest()
+ metadata = [
+ ("key", "val"),
+ ("cephalopod", "squid"),
+ ]
+ pre.return_value = request, metadata
+ post.return_value = permission_service.ListPermissionsResponse()
+
+ client.list_permissions(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+ pre.assert_called_once()
+ post.assert_called_once()
+
+
+def test_list_permissions_rest_bad_request(transport: str = 'rest', request_type=permission_service.ListPermissionsRequest):
+ client = PermissionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # send a request that will satisfy transcoding
+ request_init = {'parent': 'tunedModels/sample1'}
+ request = request_type(**request_init)
+
+ # Mock the http request call within the method and fake a BadRequest error.
+ with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_permissions(request) + + +def test_list_permissions_rest_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = permission_service.ListPermissionsResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {'parent': 'tunedModels/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + parent='parent_value', + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = permission_service.ListPermissionsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.list_permissions(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{parent=tunedModels/*}/permissions" % client.transport._host, args[1]) + + +def test_list_permissions_rest_flattened_error(transport: str = 'rest'): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_permissions( + permission_service.ListPermissionsRequest(), + parent='parent_value', + ) + + +def test_list_permissions_rest_pager(transport: str = 'rest'): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. 
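+ # (The pager is fed four JSON pages via req.side_effect below; the
+ # client follows each next_page_token until it receives a page without
+ # one, so iterating yields all six permissions across the pages.)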
+ #with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + permission.Permission(), + ], + next_page_token='abc', + ), + permission_service.ListPermissionsResponse( + permissions=[], + next_page_token='def', + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + ], + next_page_token='ghi', + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(permission_service.ListPermissionsResponse.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode('UTF-8') + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {'parent': 'tunedModels/sample1'} + + pager = client.list_permissions(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, permission.Permission) + for i in results) + + pages = list(client.list_permissions(request=sample_request).pages) + for page_, token in zip(pages, ['abc','def','ghi', '']): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize("request_type", [ + permission_service.UpdatePermissionRequest, + dict, +]) +def test_update_permission_rest(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'permission': {'name': 'tunedModels/sample1/permissions/sample2'}} + request_init["permission"] = {'name': 'tunedModels/sample1/permissions/sample2', 'grantee_type': 1, 'email_address': 'email_address_value', 'role': 1} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = gag_permission.Permission( + name='name_value', + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address='email_address_value', + role=gag_permission.Permission.Role.OWNER, + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = gag_permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.update_permission(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, gag_permission.Permission) + assert response.name == 'name_value' + assert response.grantee_type == gag_permission.Permission.GranteeType.USER + assert response.email_address == 'email_address_value' + assert response.role == gag_permission.Permission.Role.OWNER + + +def test_update_permission_rest_required_fields(request_type=permission_service.UpdatePermissionRequest): + transport_class = transports.PermissionServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).update_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).update_permission._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("update_mask", )) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = gag_permission.Permission() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request)
+ transcode_result = {
+ 'uri': 'v1/sample_method',
+ 'method': "patch",
+ 'query_params': pb_request,
+ }
+ transcode_result['body'] = pb_request
+ transcode.return_value = transcode_result
+
+ response_value = Response()
+ response_value.status_code = 200
+
+ pb_return_value = gag_permission.Permission.pb(return_value)
+ json_return_value = json_format.MessageToJson(pb_return_value)
+
+ response_value._content = json_return_value.encode('UTF-8')
+ req.return_value = response_value
+
+ response = client.update_permission(request)
+
+ expected_params = [
+ ('$alt', 'json;enum-encoding=int')
+ ]
+ actual_params = req.call_args.kwargs['params']
+ assert expected_params == actual_params
+
+
+def test_update_permission_rest_unset_required_fields():
+ transport = transports.PermissionServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+ unset_fields = transport.update_permission._get_unset_required_fields({})
+ assert set(unset_fields) == (set(("updateMask", )) & set(("permission", "updateMask", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_update_permission_rest_interceptors(null_interceptor):
+ transport = transports.PermissionServiceRestTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ interceptor=None if null_interceptor else transports.PermissionServiceRestInterceptor(),
+ )
+ client = PermissionServiceClient(transport=transport)
+ with mock.patch.object(type(client.transport._session), "request") as req, \
+ mock.patch.object(path_template, "transcode") as transcode, \
+ mock.patch.object(transports.PermissionServiceRestInterceptor, "post_update_permission") as post, \
+ mock.patch.object(transports.PermissionServiceRestInterceptor, "pre_update_permission") as pre:
+ pre.assert_not_called()
+ post.assert_not_called()
+ pb_message = permission_service.UpdatePermissionRequest.pb(permission_service.UpdatePermissionRequest())
+ transcode.return_value = {
+ "method": "post",
+ "uri": "my_uri",
+ "body": pb_message,
+ "query_params": pb_message,
+ }
+
+ req.return_value = Response()
+ req.return_value.status_code = 200
+ req.return_value.request = PreparedRequest()
+ req.return_value._content = gag_permission.Permission.to_json(gag_permission.Permission())
+
+ request = permission_service.UpdatePermissionRequest()
+ metadata = [
+ ("key", "val"),
+ ("cephalopod", "squid"),
+ ]
+ pre.return_value = request, metadata
+ post.return_value = gag_permission.Permission()
+
+ client.update_permission(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+ pre.assert_called_once()
+ post.assert_called_once()
+
+
+def test_update_permission_rest_bad_request(transport: str = 'rest', request_type=permission_service.UpdatePermissionRequest):
+ client = PermissionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # send a request that will satisfy transcoding
+ request_init = {'permission': {'name': 'tunedModels/sample1/permissions/sample2'}}
+ request_init["permission"] = {'name': 'tunedModels/sample1/permissions/sample2', 'grantee_type': 1, 'email_address': 'email_address_value', 'role': 1}
+ request = request_type(**request_init)
+
+ # Mock the http request call within the method and fake a BadRequest error.
+ with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.update_permission(request) + + +def test_update_permission_rest_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = gag_permission.Permission() + + # get arguments that satisfy an http rule for this method + sample_request = {'permission': {'name': 'tunedModels/sample1/permissions/sample2'}} + + # get truthy value for each flattened field + mock_args = dict( + permission=gag_permission.Permission(name='name_value'), + update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = gag_permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.update_permission(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{permission.name=tunedModels/*/permissions/*}" % client.transport._host, args[1]) + + +def test_update_permission_rest_flattened_error(transport: str = 'rest'): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_permission( + permission_service.UpdatePermissionRequest(), + permission=gag_permission.Permission(name='name_value'), + update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + ) + + +def test_update_permission_rest_error(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + permission_service.DeletePermissionRequest, + dict, +]) +def test_delete_permission_rest(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'name': 'tunedModels/sample1/permissions/sample2'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = '' + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.delete_permission(request) + + # Establish that the response is the type that we expect. 
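+ # (DeletePermission maps to google.protobuf.Empty on the wire; the REST
+ # transport decodes the empty body and the client surfaces it as None.)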
+ assert response is None + + +def test_delete_permission_rest_required_fields(request_type=permission_service.DeletePermissionRequest): + transport_class = transports.PermissionServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).delete_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = 'name_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).delete_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == 'name_value' + + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request)
+ transcode_result = {
+ 'uri': 'v1/sample_method',
+ 'method': "delete",
+ 'query_params': pb_request,
+ }
+ transcode.return_value = transcode_result
+
+ response_value = Response()
+ response_value.status_code = 200
+ json_return_value = ''
+
+ response_value._content = json_return_value.encode('UTF-8')
+ req.return_value = response_value
+
+ response = client.delete_permission(request)
+
+ expected_params = [
+ ('$alt', 'json;enum-encoding=int')
+ ]
+ actual_params = req.call_args.kwargs['params']
+ assert expected_params == actual_params
+
+
+def test_delete_permission_rest_unset_required_fields():
+ transport = transports.PermissionServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+ unset_fields = transport.delete_permission._get_unset_required_fields({})
+ assert set(unset_fields) == (set(()) & set(("name", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_delete_permission_rest_interceptors(null_interceptor):
+ transport = transports.PermissionServiceRestTransport(
+ credentials=ga_credentials.AnonymousCredentials(),
+ interceptor=None if null_interceptor else transports.PermissionServiceRestInterceptor(),
+ )
+ client = PermissionServiceClient(transport=transport)
+ with mock.patch.object(type(client.transport._session), "request") as req, \
+ mock.patch.object(path_template, "transcode") as transcode, \
+ mock.patch.object(transports.PermissionServiceRestInterceptor, "pre_delete_permission") as pre:
+ pre.assert_not_called()
+ pb_message = permission_service.DeletePermissionRequest.pb(permission_service.DeletePermissionRequest())
+ transcode.return_value = {
+ "method": "post",
+ "uri": "my_uri",
+ "body": pb_message,
+ "query_params": pb_message,
+ }
+
+ req.return_value = Response()
+ req.return_value.status_code = 200
+ req.return_value.request = PreparedRequest()
+
+ request = permission_service.DeletePermissionRequest()
+ metadata = [
+ ("key", "val"),
+ ("cephalopod", "squid"),
+ ]
+ pre.return_value = request, metadata
+
+ client.delete_permission(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+ pre.assert_called_once()
+
+
+def test_delete_permission_rest_bad_request(transport: str = 'rest', request_type=permission_service.DeletePermissionRequest):
+ client = PermissionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport=transport,
+ )
+
+ # send a request that will satisfy transcoding
+ request_init = {'name': 'tunedModels/sample1/permissions/sample2'}
+ request = request_type(**request_init)
+
+ # Mock the http request call within the method and fake a BadRequest error.
+ with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
+ # Wrap the value into a proper Response obj
+ response_value = Response()
+ response_value.status_code = 400
+ response_value.request = Request()
+ req.return_value = response_value
+ client.delete_permission(request)
+
+
+def test_delete_permission_rest_flattened():
+ client = PermissionServiceClient(
+ credentials=ga_credentials.AnonymousCredentials(),
+ transport="rest",
+ )
+
+ # Mock the http request call within the method and fake a response.
+ with mock.patch.object(type(client.transport._session), 'request') as req:
+ # Designate an appropriate value for the returned response.
+ return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = {'name': 'tunedModels/sample1/permissions/sample2'} + + # get truthy value for each flattened field + mock_args = dict( + name='name_value', + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = '' + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.delete_permission(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{name=tunedModels/*/permissions/*}" % client.transport._host, args[1]) + + +def test_delete_permission_rest_flattened_error(transport: str = 'rest'): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_permission( + permission_service.DeletePermissionRequest(), + name='name_value', + ) + + +def test_delete_permission_rest_error(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + permission_service.TransferOwnershipRequest, + dict, +]) +def test_transfer_ownership_rest(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'name': 'tunedModels/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = permission_service.TransferOwnershipResponse( + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = permission_service.TransferOwnershipResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.transfer_ownership(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, permission_service.TransferOwnershipResponse) + + +def test_transfer_ownership_rest_required_fields(request_type=permission_service.TransferOwnershipRequest): + transport_class = transports.PermissionServiceRestTransport + + request_init = {} + request_init["name"] = "" + request_init["email_address"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).transfer_ownership._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = 'name_value' + jsonified_request["emailAddress"] = 'email_address_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).transfer_ownership._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == 'name_value' + assert "emailAddress" in jsonified_request + assert jsonified_request["emailAddress"] == 'email_address_value' + + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = permission_service.TransferOwnershipResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+            pb_request = request_type.pb(request)
+            transcode_result = {
+                'uri': 'v1/sample_method',
+                'method': "post",
+                'query_params': pb_request,
+            }
+            transcode_result['body'] = pb_request
+            transcode.return_value = transcode_result
+
+            response_value = Response()
+            response_value.status_code = 200
+
+            pb_return_value = permission_service.TransferOwnershipResponse.pb(return_value)
+            json_return_value = json_format.MessageToJson(pb_return_value)
+
+            response_value._content = json_return_value.encode('UTF-8')
+            req.return_value = response_value
+
+            response = client.transfer_ownership(request)
+
+            expected_params = [
+                ('$alt', 'json;enum-encoding=int')
+            ]
+            actual_params = req.call_args.kwargs['params']
+            assert expected_params == actual_params
+
+
+def test_transfer_ownership_rest_unset_required_fields():
+    transport = transports.PermissionServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+    unset_fields = transport.transfer_ownership._get_unset_required_fields({})
+    assert set(unset_fields) == (set(()) & set(("name", "emailAddress", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_transfer_ownership_rest_interceptors(null_interceptor):
+    transport = transports.PermissionServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.PermissionServiceRestInterceptor(),
+        )
+    client = PermissionServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+         mock.patch.object(path_template, "transcode") as transcode, \
+         mock.patch.object(transports.PermissionServiceRestInterceptor, "post_transfer_ownership") as post, \
+         mock.patch.object(transports.PermissionServiceRestInterceptor, "pre_transfer_ownership") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = permission_service.TransferOwnershipRequest.pb(permission_service.TransferOwnershipRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = permission_service.TransferOwnershipResponse.to_json(permission_service.TransferOwnershipResponse())
+
+        request = permission_service.TransferOwnershipRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = permission_service.TransferOwnershipResponse()
+
+        client.transfer_ownership(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_transfer_ownership_rest_bad_request(transport: str = 'rest', request_type=permission_service.TransferOwnershipRequest):
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'name': 'tunedModels/sample1'}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
+ with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.transfer_ownership(request) + + +def test_transfer_ownership_rest_error(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.PermissionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.PermissionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PermissionServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.PermissionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PermissionServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PermissionServiceClient( + client_options=options, + credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.PermissionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PermissionServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.PermissionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = PermissionServiceClient(transport=transport) + assert client.transport is transport + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.PermissionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.PermissionServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + +@pytest.mark.parametrize("transport_class", [ + transports.PermissionServiceGrpcTransport, + transports.PermissionServiceGrpcAsyncIOTransport, + transports.PermissionServiceRestTransport, +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. 
+    with mock.patch.object(google.auth, 'default') as adc:
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport_class()
+        adc.assert_called_once()
+
+@pytest.mark.parametrize("transport_name", [
+    "grpc",
+    "rest",
+])
+def test_transport_kind(transport_name):
+    transport = PermissionServiceClient.get_transport_class(transport_name)(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    assert transport.kind == transport_name
+
+def test_transport_grpc_default():
+    # A client should use the gRPC transport by default.
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+    assert isinstance(
+        client.transport,
+        transports.PermissionServiceGrpcTransport,
+    )
+
+def test_permission_service_base_transport_error():
+    # Passing both a credentials object and credentials_file should raise an error
+    with pytest.raises(core_exceptions.DuplicateCredentialArgs):
+        transport = transports.PermissionServiceTransport(
+            credentials=ga_credentials.AnonymousCredentials(),
+            credentials_file="credentials.json"
+        )
+
+
+def test_permission_service_base_transport():
+    # Instantiate the base transport.
+    with mock.patch('google.ai.generativelanguage_v1beta3.services.permission_service.transports.PermissionServiceTransport.__init__') as Transport:
+        Transport.return_value = None
+        transport = transports.PermissionServiceTransport(
+            credentials=ga_credentials.AnonymousCredentials(),
+        )
+
+    # Every method on the transport should just blindly
+    # raise NotImplementedError.
+    methods = (
+        'create_permission',
+        'get_permission',
+        'list_permissions',
+        'update_permission',
+        'delete_permission',
+        'transfer_ownership',
+    )
+    for method in methods:
+        with pytest.raises(NotImplementedError):
+            getattr(transport, method)(request=object())
+
+    with pytest.raises(NotImplementedError):
+        transport.close()
+
+    # Catch all for all remaining methods and properties
+    remainder = [
+        'kind',
+    ]
+    for r in remainder:
+        with pytest.raises(NotImplementedError):
+            getattr(transport, r)()
+
+
+def test_permission_service_base_transport_with_credentials_file():
+    # Instantiate the base transport with a credentials file
+    with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta3.services.permission_service.transports.PermissionServiceTransport._prep_wrapped_messages') as Transport:
+        Transport.return_value = None
+        load_creds.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport = transports.PermissionServiceTransport(
+            credentials_file="credentials.json",
+            quota_project_id="octopus",
+        )
+        load_creds.assert_called_once_with("credentials.json",
+            scopes=None,
+            default_scopes=(),
+            quota_project_id="octopus",
+        )
+
+
+def test_permission_service_base_transport_with_adc():
+    # Test the default credentials are used if credentials and credentials_file are None.
+    with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta3.services.permission_service.transports.PermissionServiceTransport._prep_wrapped_messages') as Transport:
+        Transport.return_value = None
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport = transports.PermissionServiceTransport()
+        adc.assert_called_once()
+
+
+def test_permission_service_auth_adc():
+    # If no credentials are provided, we should use ADC credentials.
+    with mock.patch.object(google.auth, 'default', autospec=True) as adc:
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        PermissionServiceClient()
+        adc.assert_called_once_with(
+            scopes=None,
+            default_scopes=(),
+            quota_project_id=None,
+        )
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.PermissionServiceGrpcTransport,
+        transports.PermissionServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_permission_service_transport_auth_adc(transport_class):
+    # If credentials and host are not provided, the transport class should use
+    # ADC credentials.
+    with mock.patch.object(google.auth, 'default', autospec=True) as adc:
+        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
+        transport_class(quota_project_id="octopus", scopes=["1", "2"])
+        adc.assert_called_once_with(
+            scopes=["1", "2"],
+            default_scopes=(),
+            quota_project_id="octopus",
+        )
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.PermissionServiceGrpcTransport,
+        transports.PermissionServiceGrpcAsyncIOTransport,
+        transports.PermissionServiceRestTransport,
+    ],
+)
+def test_permission_service_transport_auth_gdch_credentials(transport_class):
+    host = 'https://language.com'
+    api_audience_tests = [None, 'https://language2.com']
+    api_audience_expect = [host, 'https://language2.com']
+    for t, e in zip(api_audience_tests, api_audience_expect):
+        with mock.patch.object(google.auth, 'default', autospec=True) as adc:
+            gdch_mock = mock.MagicMock()
+            type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock)
+            adc.return_value = (gdch_mock, None)
+            transport_class(host=host, api_audience=t)
+            gdch_mock.with_gdch_audience.assert_called_once_with(
+                e
+            )
+
+
+@pytest.mark.parametrize(
+    "transport_class,grpc_helpers",
+    [
+        (transports.PermissionServiceGrpcTransport, grpc_helpers),
+        (transports.PermissionServiceGrpcAsyncIOTransport, grpc_helpers_async)
+    ],
+)
+def test_permission_service_transport_create_channel(transport_class, grpc_helpers):
+    # If credentials and host are not provided, the transport class should use
+    # ADC credentials.
+    with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object(
+        grpc_helpers, "create_channel", autospec=True
+    ) as create_channel:
+        creds = ga_credentials.AnonymousCredentials()
+        adc.return_value = (creds, None)
+        transport_class(
+            quota_project_id="octopus",
+            scopes=["1", "2"]
+        )
+
+        create_channel.assert_called_with(
+            "generativelanguage.googleapis.com:443",
+            credentials=creds,
+            credentials_file=None,
+            quota_project_id="octopus",
+            default_scopes=(),
+            scopes=["1", "2"],
+            default_host="generativelanguage.googleapis.com",
+            ssl_credentials=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+
+@pytest.mark.parametrize("transport_class", [transports.PermissionServiceGrpcTransport, transports.PermissionServiceGrpcAsyncIOTransport])
+def test_permission_service_grpc_transport_client_cert_source_for_mtls(
+    transport_class
+):
+    cred = ga_credentials.AnonymousCredentials()
+
+    # Check ssl_channel_credentials is used if provided.
+    with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+        mock_ssl_channel_creds = mock.Mock()
+        transport_class(
+            host="squid.clam.whelk",
+            credentials=cred,
+            ssl_channel_credentials=mock_ssl_channel_creds
+        )
+        mock_create_channel.assert_called_once_with(
+            "squid.clam.whelk:443",
+            credentials=cred,
+            credentials_file=None,
+            scopes=None,
+            ssl_credentials=mock_ssl_channel_creds,
+            quota_project_id=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+    # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+    # is used.
+    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+            transport_class(
+                credentials=cred,
+                client_cert_source_for_mtls=client_cert_source_callback
+            )
+            expected_cert, expected_key = client_cert_source_callback()
+            mock_ssl_cred.assert_called_once_with(
+                certificate_chain=expected_cert,
+                private_key=expected_key
+            )
+
+def test_permission_service_http_transport_client_cert_source_for_mtls():
+    cred = ga_credentials.AnonymousCredentials()
+    with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel:
+        transports.PermissionServiceRestTransport(
+            credentials=cred,
+            client_cert_source_for_mtls=client_cert_source_callback
+        )
+        mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)
+
+
+@pytest.mark.parametrize("transport_name", [
+    "grpc",
+    "grpc_asyncio",
+    "rest",
+])
+def test_permission_service_host_no_port(transport_name):
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'),
+        transport=transport_name,
+    )
+    assert client.transport._host == (
+        'generativelanguage.googleapis.com:443'
+        if transport_name in ['grpc', 'grpc_asyncio']
+        else 'https://generativelanguage.googleapis.com'
+    )
+
+@pytest.mark.parametrize("transport_name", [
+    "grpc",
+    "grpc_asyncio",
+    "rest",
+])
+def test_permission_service_host_with_port(transport_name):
+    client = PermissionServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'),
+        transport=transport_name,
+    )
+    assert client.transport._host == (
+        'generativelanguage.googleapis.com:8000'
+        if transport_name in ['grpc', 'grpc_asyncio']
+        else 'https://generativelanguage.googleapis.com:8000'
+    )
+
+@pytest.mark.parametrize("transport_name", [
+    "rest",
+])
+def test_permission_service_client_transport_session_collision(transport_name):
+    creds1 = ga_credentials.AnonymousCredentials()
+    creds2 = ga_credentials.AnonymousCredentials()
+    client1 = PermissionServiceClient(
+        credentials=creds1,
+        transport=transport_name,
+    )
+    client2 = PermissionServiceClient(
+        credentials=creds2,
+        transport=transport_name,
+    )
+    session1 = client1.transport.create_permission._session
+    session2 = client2.transport.create_permission._session
+    assert session1 != session2
+    session1 = client1.transport.get_permission._session
+    session2 = client2.transport.get_permission._session
+    assert session1 != session2
+    session1 = client1.transport.list_permissions._session
+    session2 = client2.transport.list_permissions._session
+    assert session1 != session2
+    session1 = client1.transport.update_permission._session
+    session2 = client2.transport.update_permission._session
+    assert session1 != session2
+    session1 = client1.transport.delete_permission._session
+    session2 = client2.transport.delete_permission._session
+    assert session1 != session2
+    session1 = client1.transport.transfer_ownership._session
+    session2 = client2.transport.transfer_ownership._session
+    assert session1 != session2
+
+
+def test_permission_service_grpc_transport_channel():
+    channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.PermissionServiceGrpcTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_permission_service_grpc_asyncio_transport_channel():
+    channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.PermissionServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk",
+        channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize("transport_class", [transports.PermissionServiceGrpcTransport, transports.PermissionServiceGrpcAsyncIOTransport])
+def test_permission_service_transport_channel_mtls_with_client_cert_source(
+    transport_class
+):
+    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
+        with mock.patch.object(transport_class, "create_channel") as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = ga_credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(google.auth, 'default') as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=None,
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize("transport_class", [transports.PermissionServiceGrpcTransport, transports.PermissionServiceGrpcAsyncIOTransport]) +def test_permission_service_transport_channel_mtls_with_adc( + transport_class +): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_permission_path(): + tuned_model = "squid" + permission = "clam" + expected = "tunedModels/{tuned_model}/permissions/{permission}".format(tuned_model=tuned_model, permission=permission, ) + actual = PermissionServiceClient.permission_path(tuned_model, permission) + assert expected == actual + + +def test_parse_permission_path(): + expected = { + "tuned_model": "whelk", + "permission": "octopus", + } + path = PermissionServiceClient.permission_path(**expected) + + # Check that the path construction is reversible. + actual = PermissionServiceClient.parse_permission_path(path) + assert expected == actual + +def test_tuned_model_path(): + tuned_model = "oyster" + expected = "tunedModels/{tuned_model}".format(tuned_model=tuned_model, ) + actual = PermissionServiceClient.tuned_model_path(tuned_model) + assert expected == actual + + +def test_parse_tuned_model_path(): + expected = { + "tuned_model": "nudibranch", + } + path = PermissionServiceClient.tuned_model_path(**expected) + + # Check that the path construction is reversible. + actual = PermissionServiceClient.parse_tuned_model_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "cuttlefish" + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = PermissionServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "mussel", + } + path = PermissionServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = PermissionServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "winkle" + expected = "folders/{folder}".format(folder=folder, ) + actual = PermissionServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nautilus", + } + path = PermissionServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. 
+ actual = PermissionServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "scallop" + expected = "organizations/{organization}".format(organization=organization, ) + actual = PermissionServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "abalone", + } + path = PermissionServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = PermissionServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "squid" + expected = "projects/{project}".format(project=project, ) + actual = PermissionServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "clam", + } + path = PermissionServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = PermissionServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "whelk" + location = "octopus" + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = PermissionServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + } + path = PermissionServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = PermissionServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.PermissionServiceTransport, '_prep_wrapped_messages') as prep: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.PermissionServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = PermissionServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = PermissionServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "rest": "_session", + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport + ) + with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close: + with client: + close.assert_not_called() + close.assert_called_once() + +def test_client_ctx(): + transports = [ + 'rest', + 'grpc', + ] + for transport in transports: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport + ) + # Test client calls underlying transport. 
+ with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + +@pytest.mark.parametrize("client_class,transport_class", [ + (PermissionServiceClient, transports.PermissionServiceGrpcTransport), + (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport), +]) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_text_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_text_service.py new file mode 100644 index 000000000000..045435db1fc0 --- /dev/null +++ b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_text_service.py @@ -0,0 +1,3177 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+import os
+# try/except added for compatibility with python < 3.8
+try:
+    from unittest import mock
+    from unittest.mock import AsyncMock  # pragma: NO COVER
+except ImportError:  # pragma: NO COVER
+    import mock
+
+import grpc
+from grpc.experimental import aio
+from collections.abc import Iterable
+from google.protobuf import json_format
+import json
+import math
+import pytest
+from proto.marshal.rules.dates import DurationRule, TimestampRule
+from proto.marshal.rules import wrappers
+from requests import Response
+from requests import Request, PreparedRequest
+from requests.sessions import Session
+
+from google.ai.generativelanguage_v1beta3.services.text_service import TextServiceAsyncClient
+from google.ai.generativelanguage_v1beta3.services.text_service import TextServiceClient
+from google.ai.generativelanguage_v1beta3.services.text_service import transports
+from google.ai.generativelanguage_v1beta3.types import safety
+from google.ai.generativelanguage_v1beta3.types import text_service
+from google.api_core import client_options
+from google.api_core import exceptions as core_exceptions
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers
+from google.api_core import grpc_helpers_async
+from google.api_core import path_template
+from google.auth import credentials as ga_credentials
+from google.auth.exceptions import MutualTLSChannelError
+from google.longrunning import operations_pb2  # type: ignore
+from google.oauth2 import service_account
+import google.auth
+
+
+def client_cert_source_callback():
+    return b"cert bytes", b"key bytes"
+
+
+# If default endpoint is localhost, then default mtls endpoint will be the same.
+# This method modifies the default endpoint so the client can produce a different
+# mtls endpoint for endpoint testing purposes.
+def modify_default_endpoint(client): + return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert TextServiceClient._get_default_mtls_endpoint(None) is None + assert TextServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + assert TextServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint + assert TextServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint + assert TextServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert TextServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +@pytest.mark.parametrize("client_class,transport_name", [ + (TextServiceClient, "grpc"), + (TextServiceAsyncClient, "grpc_asyncio"), + (TextServiceClient, "rest"), +]) +def test_text_service_client_from_service_account_info(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + 'generativelanguage.googleapis.com:443' + if transport_name in ['grpc', 'grpc_asyncio'] + else + 'https://generativelanguage.googleapis.com' + ) + + +@pytest.mark.parametrize("transport_class,transport_name", [ + (transports.TextServiceGrpcTransport, "grpc"), + (transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.TextServiceRestTransport, "rest"), +]) +def test_text_service_client_service_account_always_use_jwt(transport_class, transport_name): + with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize("client_class,transport_name", [ + (TextServiceClient, "grpc"), + (TextServiceAsyncClient, "grpc_asyncio"), + (TextServiceClient, "rest"), +]) +def test_text_service_client_from_service_account_file(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + 
'generativelanguage.googleapis.com:443' + if transport_name in ['grpc', 'grpc_asyncio'] + else + 'https://generativelanguage.googleapis.com' + ) + + +def test_text_service_client_get_transport_class(): + transport = TextServiceClient.get_transport_class() + available_transports = [ + transports.TextServiceGrpcTransport, + transports.TextServiceRestTransport, + ] + assert transport in available_transports + + transport = TextServiceClient.get_transport_class("grpc") + assert transport == transports.TextServiceGrpcTransport + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc"), + (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (TextServiceClient, transports.TextServiceRestTransport, "rest"), +]) +@mock.patch.object(TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient)) +@mock.patch.object(TextServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceAsyncClient)) +def test_text_service_client_client_options(client_class, transport_class, transport_name): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(TextServiceClient, 'get_transport_class') as gtc: + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(TextServiceClient, 'get_transport_class') as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class(transport=transport_name) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with pytest.raises(ValueError): + client = client_class(transport=transport_name) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions(api_audience="https://language.googleapis.com") + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com" + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", "true"), + (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", "false"), + (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), + (TextServiceClient, transports.TextServiceRestTransport, "rest", "true"), + (TextServiceClient, transports.TextServiceRestTransport, "rest", "false"), +]) +@mock.patch.object(TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient)) +@mock.patch.object(TextServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceAsyncClient)) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_text_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. 
Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): + with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class", [ + TextServiceClient, TextServiceAsyncClient +]) +@mock.patch.object(TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient)) +@mock.patch.object(TextServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceAsyncClient)) +def test_text_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): + with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + +@pytest.mark.parametrize("client_class,transport_class,transport_name", [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc"), + (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (TextServiceClient, transports.TextServiceRestTransport, "rest"), +]) +def test_text_service_client_client_options_scopes(client_class, transport_class, transport_name): + # Check the case scopes are provided. 
+ options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", grpc_helpers), + (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), + (TextServiceClient, transports.TextServiceRestTransport, "rest", None), +]) +def test_text_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + +def test_text_service_client_client_options_from_dict(): + with mock.patch('google.ai.generativelanguage_v1beta3.services.text_service.transports.TextServiceGrpcTransport.__init__') as grpc_transport: + grpc_transport.return_value = None + client = TextServiceClient( + client_options={'api_endpoint': 'squid.clam.whelk'} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", grpc_helpers), + (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), +]) +def test_text_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + # Check the case credentials file is provided. + options = client_options.ClientOptions( + credentials_file="credentials.json" + ) + + with mock.patch.object(transport_class, '__init__') as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. 
+    with mock.patch.object(
+        google.auth, "load_credentials_from_file", autospec=True
+    ) as load_creds, mock.patch.object(
+        google.auth, "default", autospec=True
+    ) as adc, mock.patch.object(
+        grpc_helpers, "create_channel"
+    ) as create_channel:
+        creds = ga_credentials.AnonymousCredentials()
+        file_creds = ga_credentials.AnonymousCredentials()
+        load_creds.return_value = (file_creds, None)
+        adc.return_value = (creds, None)
+        client = client_class(client_options=options, transport=transport_name)
+        create_channel.assert_called_with(
+            "generativelanguage.googleapis.com:443",
+            credentials=file_creds,
+            credentials_file=None,
+            quota_project_id=None,
+            default_scopes=(),
+            scopes=None,
+            default_host="generativelanguage.googleapis.com",
+            ssl_credentials=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+
+@pytest.mark.parametrize("request_type", [
+  text_service.GenerateTextRequest,
+  dict,
+])
+def test_generate_text(request_type, transport: str = 'grpc'):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.generate_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = text_service.GenerateTextResponse()
+        response = client.generate_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.GenerateTextRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, text_service.GenerateTextResponse)
+
+
+def test_generate_text_empty_call():
+    # This test is a coverage failsafe to make sure that totally empty calls,
+    # i.e. request == None and no flattened fields passed, work.
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.generate_text),
+            '__call__') as call:
+        client.generate_text()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.GenerateTextRequest()
+
+@pytest.mark.asyncio
+async def test_generate_text_async(transport: str = 'grpc_asyncio', request_type=text_service.GenerateTextRequest):
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.generate_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.GenerateTextResponse())
+        response = await client.generate_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == text_service.GenerateTextRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.GenerateTextResponse) + + +@pytest.mark.asyncio +async def test_generate_text_async_from_dict(): + await test_generate_text_async(request_type=dict) + + +def test_generate_text_field_headers(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.GenerateTextRequest() + + request.model = 'model_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_text), + '__call__') as call: + call.return_value = text_service.GenerateTextResponse() + client.generate_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model=model_value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_generate_text_field_headers_async(): + client = TextServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.GenerateTextRequest() + + request.model = 'model_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_text), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.GenerateTextResponse()) + await client.generate_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model=model_value', + ) in kw['metadata'] + + +def test_generate_text_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.generate_text), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.GenerateTextResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.generate_text( + model='model_value', + prompt=text_service.TextPrompt(text='text_value'), + temperature=0.1198, + candidate_count=1573, + max_output_tokens=1865, + top_p=0.546, + top_k=541, + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].model
+        mock_val = 'model_value'
+        assert arg == mock_val
+        arg = args[0].prompt
+        mock_val = text_service.TextPrompt(text='text_value')
+        assert arg == mock_val
+        assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6)
+        arg = args[0].candidate_count
+        mock_val = 1573
+        assert arg == mock_val
+        arg = args[0].max_output_tokens
+        mock_val = 1865
+        assert arg == mock_val
+        assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6)
+        arg = args[0].top_k
+        mock_val = 541
+        assert arg == mock_val
+
+
+def test_generate_text_flattened_error():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.generate_text(
+            text_service.GenerateTextRequest(),
+            model='model_value',
+            prompt=text_service.TextPrompt(text='text_value'),
+            temperature=0.1198,
+            candidate_count=1573,
+            max_output_tokens=1865,
+            top_p=0.546,
+            top_k=541,
+        )
+
+@pytest.mark.asyncio
+async def test_generate_text_flattened_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.generate_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.GenerateTextResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.generate_text(
+            model='model_value',
+            prompt=text_service.TextPrompt(text='text_value'),
+            temperature=0.1198,
+            candidate_count=1573,
+            max_output_tokens=1865,
+            top_p=0.546,
+            top_k=541,
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].model
+        mock_val = 'model_value'
+        assert arg == mock_val
+        arg = args[0].prompt
+        mock_val = text_service.TextPrompt(text='text_value')
+        assert arg == mock_val
+        assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6)
+        arg = args[0].candidate_count
+        mock_val = 1573
+        assert arg == mock_val
+        arg = args[0].max_output_tokens
+        mock_val = 1865
+        assert arg == mock_val
+        assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6)
+        arg = args[0].top_k
+        mock_val = 541
+        assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_generate_text_flattened_error_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.generate_text(
+            text_service.GenerateTextRequest(),
+            model='model_value',
+            prompt=text_service.TextPrompt(text='text_value'),
+            temperature=0.1198,
+            candidate_count=1573,
+            max_output_tokens=1865,
+            top_p=0.546,
+            top_k=541,
+        )
+
+
+@pytest.mark.parametrize("request_type", [
+  text_service.EmbedTextRequest,
+  dict,
+])
+def test_embed_text(request_type, transport: str = 'grpc'):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = text_service.EmbedTextResponse(
+        )
+        response = client.embed_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.EmbedTextRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, text_service.EmbedTextResponse)
+
+
+def test_embed_text_empty_call():
+    # This test is a coverage failsafe to make sure that totally empty calls,
+    # i.e. request == None and no flattened fields passed, work.
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        client.embed_text()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.EmbedTextRequest()
+
+@pytest.mark.asyncio
+async def test_embed_text_async(transport: str = 'grpc_asyncio', request_type=text_service.EmbedTextRequest):
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.EmbedTextResponse(
+        ))
+        response = await client.embed_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.EmbedTextRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, text_service.EmbedTextResponse)
+
+
+@pytest.mark.asyncio
+async def test_embed_text_async_from_dict():
+    await test_embed_text_async(request_type=dict)
+
+
+def test_embed_text_field_headers():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = text_service.EmbedTextRequest()
+
+    request.model = 'model_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
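+    # For context (assumed client behavior, verified only indirectly below):
+    # the generated client is expected to build the routing header from the
+    # request's URI fields, roughly
+    #     params = '&'.join(f'{k}={v}' for k, v in (('model', request.model),))
+    #     metadata += (('x-goog-request-params', params),)
+    # The assertions at the end of this test check for exactly that entry.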
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        call.return_value = text_service.EmbedTextResponse()
+        client.embed_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'model=model_value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_embed_text_field_headers_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = text_service.EmbedTextRequest()
+
+    request.model = 'model_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.EmbedTextResponse())
+        await client.embed_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'model=model_value',
+    ) in kw['metadata']
+
+
+def test_embed_text_flattened():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = text_service.EmbedTextResponse()
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.embed_text(
+            model='model_value',
+            text='text_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].model
+        mock_val = 'model_value'
+        assert arg == mock_val
+        arg = args[0].text
+        mock_val = 'text_value'
+        assert arg == mock_val
+
+
+def test_embed_text_flattened_error():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.embed_text(
+            text_service.EmbedTextRequest(),
+            model='model_value',
+            text='text_value',
+        )
+
+@pytest.mark.asyncio
+async def test_embed_text_flattened_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.embed_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call, wrapped for
+        # the awaiting async client.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.EmbedTextResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.embed_text(
+            model='model_value',
+            text='text_value',
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].model
+        mock_val = 'model_value'
+        assert arg == mock_val
+        arg = args[0].text
+        mock_val = 'text_value'
+        assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_embed_text_flattened_error_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.embed_text(
+            text_service.EmbedTextRequest(),
+            model='model_value',
+            text='text_value',
+        )
+
+
+@pytest.mark.parametrize("request_type", [
+  text_service.BatchEmbedTextRequest,
+  dict,
+])
+def test_batch_embed_text(request_type, transport: str = 'grpc'):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.batch_embed_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = text_service.BatchEmbedTextResponse(
+        )
+        response = client.batch_embed_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.BatchEmbedTextRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, text_service.BatchEmbedTextResponse)
+
+
+def test_batch_embed_text_empty_call():
+    # This test is a coverage failsafe to make sure that totally empty calls,
+    # i.e. request == None and no flattened fields passed, work.
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.batch_embed_text),
+            '__call__') as call:
+        client.batch_embed_text()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.BatchEmbedTextRequest()
+
+@pytest.mark.asyncio
+async def test_batch_embed_text_async(transport: str = 'grpc_asyncio', request_type=text_service.BatchEmbedTextRequest):
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.batch_embed_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.BatchEmbedTextResponse(
+        ))
+        response = await client.batch_embed_text(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.BatchEmbedTextRequest()
+
+    # Establish that the response is the type that we expect.
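+    # (The async client awaits the FakeUnaryUnaryCall wrapper, so `response`
+    # here is the unwrapped BatchEmbedTextResponse proto, not the fake call.)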
+ assert isinstance(response, text_service.BatchEmbedTextResponse) + + +@pytest.mark.asyncio +async def test_batch_embed_text_async_from_dict(): + await test_batch_embed_text_async(request_type=dict) + + +def test_batch_embed_text_field_headers(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.BatchEmbedTextRequest() + + request.model = 'model_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_embed_text), + '__call__') as call: + call.return_value = text_service.BatchEmbedTextResponse() + client.batch_embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model=model_value', + ) in kw['metadata'] + + +@pytest.mark.asyncio +async def test_batch_embed_text_field_headers_async(): + client = TextServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.BatchEmbedTextRequest() + + request.model = 'model_value' + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_embed_text), + '__call__') as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.BatchEmbedTextResponse()) + await client.batch_embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + 'x-goog-request-params', + 'model=model_value', + ) in kw['metadata'] + + +def test_batch_embed_text_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_embed_text), + '__call__') as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.BatchEmbedTextResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.batch_embed_text( + model='model_value', + texts=['texts_value'], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = 'model_value' + assert arg == mock_val + arg = args[0].texts + mock_val = ['texts_value'] + assert arg == mock_val + + +def test_batch_embed_text_flattened_error(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        client.batch_embed_text(
+            text_service.BatchEmbedTextRequest(),
+            model='model_value',
+            texts=['texts_value'],
+        )
+
+@pytest.mark.asyncio
+async def test_batch_embed_text_flattened_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.batch_embed_text),
+            '__call__') as call:
+        # Designate an appropriate return value for the call, wrapped for
+        # the awaiting async client.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.BatchEmbedTextResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.batch_embed_text(
+            model='model_value',
+            texts=['texts_value'],
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].model
+        mock_val = 'model_value'
+        assert arg == mock_val
+        arg = args[0].texts
+        mock_val = ['texts_value']
+        assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_batch_embed_text_flattened_error_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.batch_embed_text(
+            text_service.BatchEmbedTextRequest(),
+            model='model_value',
+            texts=['texts_value'],
+        )
+
+
+@pytest.mark.parametrize("request_type", [
+  text_service.CountTextTokensRequest,
+  dict,
+])
+def test_count_text_tokens(request_type, transport: str = 'grpc'):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.count_text_tokens),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = text_service.CountTextTokensResponse(
+            token_count=1193,
+        )
+        response = client.count_text_tokens(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.CountTextTokensRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, text_service.CountTextTokensResponse)
+    assert response.token_count == 1193
+
+
+def test_count_text_tokens_empty_call():
+    # This test is a coverage failsafe to make sure that totally empty calls,
+    # i.e. request == None and no flattened fields passed, work.
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport='grpc',
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
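+    # With no request object and no flattened fields, the client is expected
+    # to fall back to a default-constructed CountTextTokensRequest(), which
+    # is what the assertion after the call below checks.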
+    with mock.patch.object(
+            type(client.transport.count_text_tokens),
+            '__call__') as call:
+        client.count_text_tokens()
+        call.assert_called()
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.CountTextTokensRequest()
+
+@pytest.mark.asyncio
+async def test_count_text_tokens_async(transport: str = 'grpc_asyncio', request_type=text_service.CountTextTokensRequest):
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.count_text_tokens),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.CountTextTokensResponse(
+            token_count=1193,
+        ))
+        response = await client.count_text_tokens(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == text_service.CountTextTokensRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, text_service.CountTextTokensResponse)
+    assert response.token_count == 1193
+
+
+@pytest.mark.asyncio
+async def test_count_text_tokens_async_from_dict():
+    await test_count_text_tokens_async(request_type=dict)
+
+
+def test_count_text_tokens_field_headers():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = text_service.CountTextTokensRequest()
+
+    request.model = 'model_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.count_text_tokens),
+            '__call__') as call:
+        call.return_value = text_service.CountTextTokensResponse()
+        client.count_text_tokens(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'model=model_value',
+    ) in kw['metadata']
+
+
+@pytest.mark.asyncio
+async def test_count_text_tokens_field_headers_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = text_service.CountTextTokensRequest()
+
+    request.model = 'model_value'
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.count_text_tokens),
+            '__call__') as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.CountTextTokensResponse())
+        await client.count_text_tokens(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
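+    # (mock_calls entries unpack as (name, args, kwargs); the routing header
+    # travels in kwargs['metadata'] as a (key, value) tuple.)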
+    _, _, kw = call.mock_calls[0]
+    assert (
+        'x-goog-request-params',
+        'model=model_value',
+    ) in kw['metadata']
+
+
+def test_count_text_tokens_flattened():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.count_text_tokens),
+            '__call__') as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = text_service.CountTextTokensResponse()
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.count_text_tokens(
+            model='model_value',
+            prompt=text_service.TextPrompt(text='text_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].model
+        mock_val = 'model_value'
+        assert arg == mock_val
+        arg = args[0].prompt
+        mock_val = text_service.TextPrompt(text='text_value')
+        assert arg == mock_val
+
+
+def test_count_text_tokens_flattened_error():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.count_text_tokens(
+            text_service.CountTextTokensRequest(),
+            model='model_value',
+            prompt=text_service.TextPrompt(text='text_value'),
+        )
+
+@pytest.mark.asyncio
+async def test_count_text_tokens_flattened_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+            type(client.transport.count_text_tokens),
+            '__call__') as call:
+        # Designate an appropriate return value for the call, wrapped for
+        # the awaiting async client.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.CountTextTokensResponse())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.count_text_tokens(
+            model='model_value',
+            prompt=text_service.TextPrompt(text='text_value'),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        arg = args[0].model
+        mock_val = 'model_value'
+        assert arg == mock_val
+        arg = args[0].prompt
+        mock_val = text_service.TextPrompt(text='text_value')
+        assert arg == mock_val
+
+@pytest.mark.asyncio
+async def test_count_text_tokens_flattened_error_async():
+    client = TextServiceAsyncClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+ with pytest.raises(ValueError): + await client.count_text_tokens( + text_service.CountTextTokensRequest(), + model='model_value', + prompt=text_service.TextPrompt(text='text_value'), + ) + + +@pytest.mark.parametrize("request_type", [ + text_service.GenerateTextRequest, + dict, +]) +def test_generate_text_rest(request_type): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'model': 'models/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = text_service.GenerateTextResponse( + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = text_service.GenerateTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.generate_text(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.GenerateTextResponse) + + +def test_generate_text_rest_required_fields(request_type=text_service.GenerateTextRequest): + transport_class = transports.TextServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = 'model_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == 'model_value' + + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = text_service.GenerateTextResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
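+            # (For context, a hedged sketch of what a real transcode() result
+            # would resemble for an http rule such as
+            # 'post /v1beta3/{model=models/*}:generateText':
+            #     {'uri': '/v1beta3/models/sample1:generateText',
+            #      'method': 'post', 'query_params': {...}, 'body': {...}}
+            # Here it is stubbed with a fixed placeholder result instead.)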
+            pb_request = request_type.pb(request)
+            transcode_result = {
+                'uri': 'v1/sample_method',
+                'method': "post",
+                'query_params': pb_request,
+            }
+            transcode_result['body'] = pb_request
+            transcode.return_value = transcode_result
+
+            response_value = Response()
+            response_value.status_code = 200
+
+            pb_return_value = text_service.GenerateTextResponse.pb(return_value)
+            json_return_value = json_format.MessageToJson(pb_return_value)
+
+            response_value._content = json_return_value.encode('UTF-8')
+            req.return_value = response_value
+
+            response = client.generate_text(request)
+
+            expected_params = [
+                ('$alt', 'json;enum-encoding=int')
+            ]
+            actual_params = req.call_args.kwargs['params']
+            assert expected_params == actual_params
+
+
+def test_generate_text_rest_unset_required_fields():
+    transport = transports.TextServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+    unset_fields = transport.generate_text._get_unset_required_fields({})
+    assert set(unset_fields) == (set(()) & set(("model", "prompt", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_generate_text_rest_interceptors(null_interceptor):
+    transport = transports.TextServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.TextServiceRestInterceptor(),
+        )
+    client = TextServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+         mock.patch.object(path_template, "transcode") as transcode, \
+         mock.patch.object(transports.TextServiceRestInterceptor, "post_generate_text") as post, \
+         mock.patch.object(transports.TextServiceRestInterceptor, "pre_generate_text") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = text_service.GenerateTextRequest.pb(text_service.GenerateTextRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = text_service.GenerateTextResponse.to_json(text_service.GenerateTextResponse())
+
+        request = text_service.GenerateTextRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = text_service.GenerateTextResponse()
+
+        client.generate_text(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_generate_text_rest_bad_request(transport: str = 'rest', request_type=text_service.GenerateTextRequest):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'model': 'models/sample1'}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
+    with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 400
+        response_value.request = Request()
+        req.return_value = response_value
+        client.generate_text(request)
+
+
+def test_generate_text_rest_flattened():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # Mock the http request call within the method and fake a response.
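+    # Beyond type checks, this test also verifies that the flattened
+    # arguments produce a URL matching the method's http rule (see the
+    # path_template.validate assertion below).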
+ with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = text_service.GenerateTextResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {'model': 'models/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + model='model_value', + prompt=text_service.TextPrompt(text='text_value'), + temperature=0.1198, + candidate_count=1573, + max_output_tokens=1865, + top_p=0.546, + top_k=541, + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = text_service.GenerateTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.generate_text(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{model=models/*}:generateText" % client.transport._host, args[1]) + + +def test_generate_text_rest_flattened_error(transport: str = 'rest'): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.generate_text( + text_service.GenerateTextRequest(), + model='model_value', + prompt=text_service.TextPrompt(text='text_value'), + temperature=0.1198, + candidate_count=1573, + max_output_tokens=1865, + top_p=0.546, + top_k=541, + ) + + +def test_generate_text_rest_error(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + text_service.EmbedTextRequest, + dict, +]) +def test_embed_text_rest(request_type): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'model': 'models/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = text_service.EmbedTextResponse( + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = text_service.EmbedTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.embed_text(request) + + # Establish that the response is the type that we expect. 
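+    # (The mocked session returned JSON; the REST transport is expected to
+    # have parsed it back into an EmbedTextResponse proto.)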
+ assert isinstance(response, text_service.EmbedTextResponse) + + +def test_embed_text_rest_required_fields(request_type=text_service.EmbedTextRequest): + transport_class = transports.TextServiceRestTransport + + request_init = {} + request_init["model"] = "" + request_init["text"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).embed_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = 'model_value' + jsonified_request["text"] = 'text_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).embed_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == 'model_value' + assert "text" in jsonified_request + assert jsonified_request["text"] == 'text_value' + + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = text_service.EmbedTextResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+            pb_request = request_type.pb(request)
+            transcode_result = {
+                'uri': 'v1/sample_method',
+                'method': "post",
+                'query_params': pb_request,
+            }
+            transcode_result['body'] = pb_request
+            transcode.return_value = transcode_result
+
+            response_value = Response()
+            response_value.status_code = 200
+
+            pb_return_value = text_service.EmbedTextResponse.pb(return_value)
+            json_return_value = json_format.MessageToJson(pb_return_value)
+
+            response_value._content = json_return_value.encode('UTF-8')
+            req.return_value = response_value
+
+            response = client.embed_text(request)
+
+            expected_params = [
+                ('$alt', 'json;enum-encoding=int')
+            ]
+            actual_params = req.call_args.kwargs['params']
+            assert expected_params == actual_params
+
+
+def test_embed_text_rest_unset_required_fields():
+    transport = transports.TextServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+    unset_fields = transport.embed_text._get_unset_required_fields({})
+    assert set(unset_fields) == (set(()) & set(("model", "text", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_embed_text_rest_interceptors(null_interceptor):
+    transport = transports.TextServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.TextServiceRestInterceptor(),
+        )
+    client = TextServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+         mock.patch.object(path_template, "transcode") as transcode, \
+         mock.patch.object(transports.TextServiceRestInterceptor, "post_embed_text") as post, \
+         mock.patch.object(transports.TextServiceRestInterceptor, "pre_embed_text") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = text_service.EmbedTextRequest.pb(text_service.EmbedTextRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = text_service.EmbedTextResponse.to_json(text_service.EmbedTextResponse())
+
+        request = text_service.EmbedTextRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = text_service.EmbedTextResponse()
+
+        client.embed_text(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_embed_text_rest_bad_request(transport: str = 'rest', request_type=text_service.EmbedTextRequest):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'model': 'models/sample1'}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
+    with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
+        # Wrap the value into a proper Response obj
+        response_value = Response()
+        response_value.status_code = 400
+        response_value.request = Request()
+        req.return_value = response_value
+        client.embed_text(request)
+
+
+def test_embed_text_rest_flattened():
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport="rest",
+    )
+
+    # Mock the http request call within the method and fake a response.
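+    # As in the generateText flattened test above, the final URL is checked
+    # against the embedText http rule via path_template.validate below.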
+ with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = text_service.EmbedTextResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {'model': 'models/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + model='model_value', + text='text_value', + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = text_service.EmbedTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.embed_text(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{model=models/*}:embedText" % client.transport._host, args[1]) + + +def test_embed_text_rest_flattened_error(transport: str = 'rest'): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.embed_text( + text_service.EmbedTextRequest(), + model='model_value', + text='text_value', + ) + + +def test_embed_text_rest_error(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + text_service.BatchEmbedTextRequest, + dict, +]) +def test_batch_embed_text_rest(request_type): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'model': 'models/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = text_service.BatchEmbedTextResponse( + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = text_service.BatchEmbedTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.batch_embed_text(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, text_service.BatchEmbedTextResponse) + + +def test_batch_embed_text_rest_required_fields(request_type=text_service.BatchEmbedTextRequest): + transport_class = transports.TextServiceRestTransport + + request_init = {} + request_init["model"] = "" + request_init["texts"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).batch_embed_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = 'model_value' + jsonified_request["texts"] = 'texts_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).batch_embed_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == 'model_value' + assert "texts" in jsonified_request + assert jsonified_request["texts"] == 'texts_value' + + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = text_service.BatchEmbedTextResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+            pb_request = request_type.pb(request)
+            transcode_result = {
+                'uri': 'v1/sample_method',
+                'method': "post",
+                'query_params': pb_request,
+            }
+            transcode_result['body'] = pb_request
+            transcode.return_value = transcode_result
+
+            response_value = Response()
+            response_value.status_code = 200
+
+            pb_return_value = text_service.BatchEmbedTextResponse.pb(return_value)
+            json_return_value = json_format.MessageToJson(pb_return_value)
+
+            response_value._content = json_return_value.encode('UTF-8')
+            req.return_value = response_value
+
+            response = client.batch_embed_text(request)
+
+            expected_params = [
+                ('$alt', 'json;enum-encoding=int')
+            ]
+            actual_params = req.call_args.kwargs['params']
+            assert expected_params == actual_params
+
+
+def test_batch_embed_text_rest_unset_required_fields():
+    transport = transports.TextServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+    unset_fields = transport.batch_embed_text._get_unset_required_fields({})
+    assert set(unset_fields) == (set(()) & set(("model", "texts", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_batch_embed_text_rest_interceptors(null_interceptor):
+    transport = transports.TextServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.TextServiceRestInterceptor(),
+        )
+    client = TextServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+         mock.patch.object(path_template, "transcode") as transcode, \
+         mock.patch.object(transports.TextServiceRestInterceptor, "post_batch_embed_text") as post, \
+         mock.patch.object(transports.TextServiceRestInterceptor, "pre_batch_embed_text") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = text_service.BatchEmbedTextRequest.pb(text_service.BatchEmbedTextRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = text_service.BatchEmbedTextResponse.to_json(text_service.BatchEmbedTextResponse())
+
+        request = text_service.BatchEmbedTextRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = text_service.BatchEmbedTextResponse()
+
+        client.batch_embed_text(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_batch_embed_text_rest_bad_request(transport: str = 'rest', request_type=text_service.BatchEmbedTextRequest):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'model': 'models/sample1'}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
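+    # (A 400 status on the mocked session is expected to surface as
+    # google.api_core.exceptions.BadRequest from the client method.)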
+ with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.batch_embed_text(request) + + +def test_batch_embed_text_rest_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = text_service.BatchEmbedTextResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {'model': 'models/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + model='model_value', + texts=['texts_value'], + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = text_service.BatchEmbedTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.batch_embed_text(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{model=models/*}:batchEmbedText" % client.transport._host, args[1]) + + +def test_batch_embed_text_rest_flattened_error(transport: str = 'rest'): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.batch_embed_text( + text_service.BatchEmbedTextRequest(), + model='model_value', + texts=['texts_value'], + ) + + +def test_batch_embed_text_rest_error(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +@pytest.mark.parametrize("request_type", [ + text_service.CountTextTokensRequest, + dict, +]) +def test_count_text_tokens_rest(request_type): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {'model': 'models/sample1'} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = text_service.CountTextTokensResponse( + token_count=1193, + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = text_service.CountTextTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + response = client.count_text_tokens(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, text_service.CountTextTokensResponse) + assert response.token_count == 1193 + + +def test_count_text_tokens_rest_required_fields(request_type=text_service.CountTextTokensRequest): + transport_class = transports.TextServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads(json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False + )) + + # verify fields with default values are dropped + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).count_text_tokens._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = 'model_value' + + unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).count_text_tokens._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == 'model_value' + + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest', + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = text_service.CountTextTokensResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, 'request') as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, 'transcode') as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+            pb_request = request_type.pb(request)
+            transcode_result = {
+                'uri': 'v1/sample_method',
+                'method': "post",
+                'query_params': pb_request,
+            }
+            transcode_result['body'] = pb_request
+            transcode.return_value = transcode_result
+
+            response_value = Response()
+            response_value.status_code = 200
+
+            pb_return_value = text_service.CountTextTokensResponse.pb(return_value)
+            json_return_value = json_format.MessageToJson(pb_return_value)
+
+            response_value._content = json_return_value.encode('UTF-8')
+            req.return_value = response_value
+
+            response = client.count_text_tokens(request)
+
+            expected_params = [
+                ('$alt', 'json;enum-encoding=int')
+            ]
+            actual_params = req.call_args.kwargs['params']
+            assert expected_params == actual_params
+
+
+def test_count_text_tokens_rest_unset_required_fields():
+    transport = transports.TextServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
+
+    unset_fields = transport.count_text_tokens._get_unset_required_fields({})
+    assert set(unset_fields) == (set(()) & set(("model", "prompt", )))
+
+
+@pytest.mark.parametrize("null_interceptor", [True, False])
+def test_count_text_tokens_rest_interceptors(null_interceptor):
+    transport = transports.TextServiceRestTransport(
+        credentials=ga_credentials.AnonymousCredentials(),
+        interceptor=None if null_interceptor else transports.TextServiceRestInterceptor(),
+        )
+    client = TextServiceClient(transport=transport)
+    with mock.patch.object(type(client.transport._session), "request") as req, \
+         mock.patch.object(path_template, "transcode") as transcode, \
+         mock.patch.object(transports.TextServiceRestInterceptor, "post_count_text_tokens") as post, \
+         mock.patch.object(transports.TextServiceRestInterceptor, "pre_count_text_tokens") as pre:
+        pre.assert_not_called()
+        post.assert_not_called()
+        pb_message = text_service.CountTextTokensRequest.pb(text_service.CountTextTokensRequest())
+        transcode.return_value = {
+            "method": "post",
+            "uri": "my_uri",
+            "body": pb_message,
+            "query_params": pb_message,
+        }
+
+        req.return_value = Response()
+        req.return_value.status_code = 200
+        req.return_value.request = PreparedRequest()
+        req.return_value._content = text_service.CountTextTokensResponse.to_json(text_service.CountTextTokensResponse())
+
+        request = text_service.CountTextTokensRequest()
+        metadata = [
+            ("key", "val"),
+            ("cephalopod", "squid"),
+        ]
+        pre.return_value = request, metadata
+        post.return_value = text_service.CountTextTokensResponse()
+
+        client.count_text_tokens(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
+
+        pre.assert_called_once()
+        post.assert_called_once()
+
+
+def test_count_text_tokens_rest_bad_request(transport: str = 'rest', request_type=text_service.CountTextTokensRequest):
+    client = TextServiceClient(
+        credentials=ga_credentials.AnonymousCredentials(),
+        transport=transport,
+    )
+
+    # send a request that will satisfy transcoding
+    request_init = {'model': 'models/sample1'}
+    request = request_type(**request_init)
+
+    # Mock the http request call within the method and fake a BadRequest error.
+ with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.count_text_tokens(request) + + +def test_count_text_tokens_rest_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), 'request') as req: + # Designate an appropriate value for the returned response. + return_value = text_service.CountTextTokensResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {'model': 'models/sample1'} + + # get truthy value for each flattened field + mock_args = dict( + model='model_value', + prompt=text_service.TextPrompt(text='text_value'), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + pb_return_value = text_service.CountTextTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(pb_return_value) + response_value._content = json_return_value.encode('UTF-8') + req.return_value = response_value + + client.count_text_tokens(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate("%s/v1beta3/{model=models/*}:countTextTokens" % client.transport._host, args[1]) + + +def test_count_text_tokens_rest_flattened_error(transport: str = 'rest'): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.count_text_tokens( + text_service.CountTextTokensRequest(), + model='model_value', + prompt=text_service.TextPrompt(text='text_value'), + ) + + +def test_count_text_tokens_rest_error(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport='rest' + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TextServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = TextServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. 
+ options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = TextServiceClient( + client_options=options, + credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TextServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = TextServiceClient(transport=transport) + assert client.transport is transport + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.TextServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + +@pytest.mark.parametrize("transport_class", [ + transports.TextServiceGrpcTransport, + transports.TextServiceGrpcAsyncIOTransport, + transports.TextServiceRestTransport, +]) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, 'default') as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + +@pytest.mark.parametrize("transport_name", [ + "grpc", + "rest", +]) +def test_transport_kind(transport_name): + transport = TextServiceClient.get_transport_class(transport_name)( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert transport.kind == transport_name + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.TextServiceGrpcTransport, + ) + +def test_text_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.TextServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json" + ) + + +def test_text_service_base_transport(): + # Instantiate the base transport. + with mock.patch('google.ai.generativelanguage_v1beta3.services.text_service.transports.TextServiceTransport.__init__') as Transport: + Transport.return_value = None + transport = transports.TextServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
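A reduced sketch of why that holds (editor's illustration with hypothetical names, not library code): the base transport defines every RPC as a stub that raises, and concrete transports such as the gRPC and REST classes override them, so an unimplemented method fails loudly rather than silently.

    class _BaseTransport:
        """Stand-in for the role TextServiceTransport plays."""

        def generate_text(self, request):
            raise NotImplementedError()

        def close(self):
            raise NotImplementedError()

    class _RestTransport(_BaseTransport):
        def generate_text(self, request):
            return {"echo": request}  # placeholder for a real HTTP call

        def close(self):
            pass  # a real transport would close its session here

    t = _RestTransport()
    assert t.generate_text("hi") == {"echo": "hi"}

With that shape in mind, the loop below simply walks every RPC name on the real base class and confirms the stub behavior.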
+ methods = ( + 'generate_text', + 'embed_text', + 'batch_embed_text', + 'count_text_tokens', + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Catch all for all remaining methods and properties + remainder = [ + 'kind', + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_text_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta3.services.text_service.transports.TextServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.TextServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with("credentials.json", + scopes=None, + default_scopes=( +), + quota_project_id="octopus", + ) + + +def test_text_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta3.services.text_service.transports.TextServiceTransport._prep_wrapped_messages') as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.TextServiceTransport() + adc.assert_called_once() + + +def test_text_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, 'default', autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + TextServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=( +), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.TextServiceGrpcTransport, + transports.TextServiceGrpcAsyncIOTransport, + ], +) +def test_text_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
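ADC is the fallback whenever no credentials argument is given. A sketch of the same pattern at the client level (mocking google.auth.default keeps it hermetic; a real environment would supply ADC via gcloud or GOOGLE_APPLICATION_CREDENTIALS):

    from unittest import mock
    import google.auth
    from google.auth.credentials import AnonymousCredentials
    from google.ai.generativelanguage_v1beta3.services.text_service import TextServiceClient

    with mock.patch.object(google.auth, "default", autospec=True) as adc:
        # google.auth.default returns a (credentials, project_id) tuple.
        adc.return_value = (AnonymousCredentials(), None)
        TextServiceClient()  # no credentials kwarg, so ADC is consulted
        adc.assert_called_once()

The patched google.auth.default below verifies the same thing one level down, at the transport constructor.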
+ with mock.patch.object(google.auth, 'default', autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=(), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.TextServiceGrpcTransport, + transports.TextServiceGrpcAsyncIOTransport, + transports.TextServiceRestTransport, + ], +) +def test_text_service_transport_auth_gdch_credentials(transport_class): + host = 'https://language.com' + api_audience_tests = [None, 'https://language2.com'] + api_audience_expect = [host, 'https://language2.com'] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, 'default', autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with( + e + ) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.TextServiceGrpcTransport, grpc_helpers), + (transports.TextServiceGrpcAsyncIOTransport, grpc_helpers_async) + ], +) +def test_text_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class( + quota_project_id="octopus", + scopes=["1", "2"] + ) + + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=( +), + scopes=["1", "2"], + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize("transport_class", [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport]) +def test_text_service_grpc_transport_client_cert_source_for_mtls( + transport_class +): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. 
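Both mTLS branches funnel into grpc.ssl_channel_credentials. A sketch of what a client_cert_source callback feeds it (the PEM bytes are placeholders; real mutual TLS needs a valid certificate chain and private key):

    import grpc

    def client_cert_source():
        # A real callback returns PEM-encoded (certificate_chain, private_key).
        return b"-----BEGIN CERTIFICATE-----...", b"-----BEGIN PRIVATE KEY-----..."

    cert, key = client_cert_source()
    channel_creds = grpc.ssl_channel_credentials(
        certificate_chain=cert,
        private_key=key,
    )
    # channel_creds is what the transport then passes as ssl_credentials when
    # it creates the secure channel.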
+ with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, + private_key=expected_key + ) + +def test_text_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel: + transports.TextServiceRestTransport ( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +@pytest.mark.parametrize("transport_name", [ + "grpc", + "grpc_asyncio", + "rest", +]) +def test_text_service_host_no_port(transport_name): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'), + transport=transport_name, + ) + assert client.transport._host == ( + 'generativelanguage.googleapis.com:443' + if transport_name in ['grpc', 'grpc_asyncio'] + else 'https://generativelanguage.googleapis.com' + ) + +@pytest.mark.parametrize("transport_name", [ + "grpc", + "grpc_asyncio", + "rest", +]) +def test_text_service_host_with_port(transport_name): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'), + transport=transport_name, + ) + assert client.transport._host == ( + 'generativelanguage.googleapis.com:8000' + if transport_name in ['grpc', 'grpc_asyncio'] + else 'https://generativelanguage.googleapis.com:8000' + ) + +@pytest.mark.parametrize("transport_name", [ + "rest", +]) +def test_text_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = TextServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = TextServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.generate_text._session + session2 = client2.transport.generate_text._session + assert session1 != session2 + session1 = client1.transport.embed_text._session + session2 = client2.transport.embed_text._session + assert session1 != session2 + session1 = client1.transport.batch_embed_text._session + session2 = client2.transport.batch_embed_text._session + assert session1 != session2 + session1 = client1.transport.count_text_tokens._session + session2 = client2.transport.count_text_tokens._session + assert session1 != session2 +def test_text_service_grpc_transport_channel(): + channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.TextServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_text_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + + # Check that channel is used if provided. 
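The two host tests above fix the endpoint-formatting contract: gRPC hosts gain a default :443 port, while the REST transport addresses the service as an https URL. A sketch with an explicit port, mirroring test_text_service_host_with_port (transport._host is the private attribute those tests read):

    from google.api_core import client_options
    from google.auth.credentials import AnonymousCredentials
    from google.ai.generativelanguage_v1beta3.services.text_service import TextServiceClient

    opts = client_options.ClientOptions(
        api_endpoint="generativelanguage.googleapis.com:8000"
    )
    client = TextServiceClient(
        credentials=AnonymousCredentials(),
        client_options=opts,
        transport="rest",
    )
    assert client.transport._host == "https://generativelanguage.googleapis.com:8000"

The asyncio variant below then repeats the pre-built-channel check for the async transport.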
+ transport = transports.TextServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize("transport_class", [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport]) +def test_text_service_transport_channel_mtls_with_client_cert_source( + transport_class +): + with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, 'default') as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize("transport_class", [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport]) +def test_text_service_transport_channel_mtls_with_adc( + transport_class +): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_model_path(): + model = "squid" + expected = "models/{model}".format(model=model, ) + actual = TextServiceClient.model_path(model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "model": "clam", + } + path = TextServiceClient.model_path(**expected) + + # Check that the path construction is reversible. 
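In ordinary use, the helper pair this test pins down looks like the following round trip (a sketch; the model id is illustrative):

    from google.ai.generativelanguage_v1beta3.services.text_service import TextServiceClient

    path = TextServiceClient.model_path("text-bison-001")
    assert path == "models/text-bison-001"
    assert TextServiceClient.parse_model_path(path) == {"model": "text-bison-001"}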
+ actual = TextServiceClient.parse_model_path(path) + assert expected == actual + +def test_common_billing_account_path(): + billing_account = "whelk" + expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + actual = TextServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "octopus", + } + path = TextServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_common_billing_account_path(path) + assert expected == actual + +def test_common_folder_path(): + folder = "oyster" + expected = "folders/{folder}".format(folder=folder, ) + actual = TextServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nudibranch", + } + path = TextServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_common_folder_path(path) + assert expected == actual + +def test_common_organization_path(): + organization = "cuttlefish" + expected = "organizations/{organization}".format(organization=organization, ) + actual = TextServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "mussel", + } + path = TextServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_common_organization_path(path) + assert expected == actual + +def test_common_project_path(): + project = "winkle" + expected = "projects/{project}".format(project=project, ) + actual = TextServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "nautilus", + } + path = TextServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_common_project_path(path) + assert expected == actual + +def test_common_location_path(): + project = "scallop" + location = "abalone" + expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + actual = TextServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "squid", + "location": "clam", + } + path = TextServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
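The common_* helpers share the same template-and-parse pairing across GAPIC clients. A round-trip sketch with illustrative values:

    from google.ai.generativelanguage_v1beta3.services.text_service import TextServiceClient

    loc = TextServiceClient.common_location_path("my-project", "us-central1")
    assert loc == "projects/my-project/locations/us-central1"
    assert TextServiceClient.parse_common_location_path(loc) == {
        "project": "my-project",
        "location": "us-central1",
    }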
+ actual = TextServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object(transports.TextServiceTransport, '_prep_wrapped_messages') as prep: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object(transports.TextServiceTransport, '_prep_wrapped_messages') as prep: + transport_class = TextServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + +@pytest.mark.asyncio +async def test_transport_close_async(): + client = TextServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close(): + transports = { + "rest": "_session", + "grpc": "_grpc_channel", + } + + for transport, close_name in transports.items(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport + ) + with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close: + with client: + close.assert_not_called() + close.assert_called_once() + +def test_client_ctx(): + transports = [ + 'rest', + 'grpc', + ] + for transport in transports: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport + ) + # Test client calls underlying transport. 
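Outside of tests, the contract exercised here makes the client a context manager that tears down its transport on exit; and as test_api_key_credentials further below shows, an API key can stand in for credentials. A sketch of both (the key string is a placeholder, and the api_key path assumes a google-auth version that supports API-key credentials):

    from google.api_core import client_options
    from google.auth.credentials import AnonymousCredentials
    from google.ai.generativelanguage_v1beta3.services.text_service import TextServiceClient

    # Context-manager use: the transport (and its session/channel) is closed on exit.
    with TextServiceClient(credentials=AnonymousCredentials(), transport="rest") as client:
        pass  # issue RPCs here

    # API-key use: the key is exchanged for credentials internally.
    opts = client_options.ClientOptions(api_key="YOUR_API_KEY")
    client = TextServiceClient(client_options=opts)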
+ with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + +@pytest.mark.parametrize("client_class,transport_class", [ + (TextServiceClient, transports.TextServiceGrpcTransport), + (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport), +]) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) From aecc2fb82b0820bed6819930fc5c375dd1565c36 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Tue, 19 Sep 2023 21:53:19 +0000 Subject: [PATCH 2/3] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20po?= =?UTF-8?q?st-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- .../v1beta2/.coveragerc | 13 - .../v1beta2/.flake8 | 33 - .../v1beta2/MANIFEST.in | 2 - .../v1beta2/README.rst | 49 - .../v1beta2/docs/_static/custom.css | 3 - .../v1beta2/docs/conf.py | 376 --- .../discuss_service.rst | 6 - .../model_service.rst | 10 - .../generativelanguage_v1beta2/services.rst | 8 - .../text_service.rst | 6 - .../docs/generativelanguage_v1beta2/types.rst | 6 - .../v1beta2/docs/index.rst | 7 - .../google/ai/generativelanguage/__init__.py | 85 - .../ai/generativelanguage_v1beta2/__init__.py | 86 - .../gapic_metadata.json | 156 - .../gapic_version.py | 16 - .../ai/generativelanguage_v1beta2/py.typed | 2 - .../services/discuss_service/__init__.py | 22 - .../services/discuss_service/async_client.py | 508 --- .../services/discuss_service/client.py | 712 ----- .../discuss_service/transports/base.py | 161 - .../discuss_service/transports/grpc.py | 295 -- .../transports/grpc_asyncio.py | 294 -- .../discuss_service/transports/rest.py | 433 --- .../services/model_service/__init__.py | 22 - .../services/model_service/async_client.py | 431 --- .../services/model_service/client.py | 635 ---- .../services/model_service/pagers.py | 140 - .../services/model_service/transports/base.py | 162 - .../services/model_service/transports/grpc.py | 292 -- .../model_service/transports/grpc_asyncio.py | 291 -- .../services/model_service/transports/rest.py | 397 --- .../services/text_service/async_client.py | 514 --- .../services/text_service/client.py | 718 ----- .../services/text_service/transports/base.py | 161 - .../services/text_service/transports/grpc.py | 295 -- .../text_service/transports/grpc_asyncio.py | 294 -- .../services/text_service/transports/rest.py | 423 --- .../types/__init__.py | 80 - .../types/citation.py | 102 - .../types/discuss_service.py | 358 --- .../generativelanguage_v1beta2/types/model.py | 156 - .../types/model_service.py | 114 - .../types/safety.py | 247 -- .../types/text_service.py | 333 -- .../v1beta2/mypy.ini | 3 - .../v1beta2/noxfile.py | 184 -- 
...cuss_service_count_message_tokens_async.py | 56 - ...scuss_service_count_message_tokens_sync.py | 56 - ..._discuss_service_generate_message_async.py | 56 - ...d_discuss_service_generate_message_sync.py | 56 - ...generated_model_service_get_model_async.py | 52 - ..._generated_model_service_get_model_sync.py | 52 - ...nerated_model_service_list_models_async.py | 52 - ...enerated_model_service_list_models_sync.py | 52 - ...generated_text_service_embed_text_async.py | 53 - ..._generated_text_service_embed_text_sync.py | 53 - ...erated_text_service_generate_text_async.py | 56 - ...nerated_text_service_generate_text_sync.py | 56 - ..._google.ai.generativelanguage.v1beta2.json | 1093 ------- ...xup_generativelanguage_v1beta2_keywords.py | 181 -- .../v1beta2/setup.py | 90 - .../v1beta2/testing/constraints-3.10.txt | 6 - .../v1beta2/testing/constraints-3.11.txt | 6 - .../v1beta2/testing/constraints-3.12.txt | 6 - .../v1beta2/testing/constraints-3.7.txt | 9 - .../v1beta2/testing/constraints-3.8.txt | 6 - .../v1beta2/testing/constraints-3.9.txt | 6 - .../v1beta2/tests/__init__.py | 16 - .../v1beta2/tests/unit/__init__.py | 16 - .../v1beta2/tests/unit/gapic/__init__.py | 16 - .../generativelanguage_v1beta2/__init__.py | 16 - .../test_discuss_service.py | 2205 ------------- .../test_model_service.py | 2319 -------------- .../test_text_service.py | 2214 ------------- .../v1beta3/.coveragerc | 13 - .../v1beta3/.flake8 | 33 - .../v1beta3/MANIFEST.in | 2 - .../v1beta3/README.rst | 49 - .../v1beta3/docs/_static/custom.css | 3 - .../v1beta3/docs/conf.py | 376 --- .../v1beta3/docs/index.rst | 7 - .../google/ai/generativelanguage/__init__.py | 145 - .../ai/generativelanguage/gapic_version.py | 16 - .../google/ai/generativelanguage/py.typed | 2 - .../ai/generativelanguage_v1beta3/__init__.py | 146 - .../gapic_version.py | 16 - .../ai/generativelanguage_v1beta3/py.typed | 2 - .../discuss_service/transports/__init__.py | 38 - .../model_service/transports/__init__.py | 38 - .../services/text_service/__init__.py | 22 - .../text_service/transports/__init__.py | 38 - .../v1beta3/mypy.ini | 3 - .../v1beta3/noxfile.py | 184 -- .../v1beta3/setup.py | 90 - .../v1beta3/testing/constraints-3.10.txt | 6 - .../v1beta3/testing/constraints-3.11.txt | 6 - .../v1beta3/testing/constraints-3.12.txt | 6 - .../v1beta3/testing/constraints-3.7.txt | 9 - .../v1beta3/testing/constraints-3.8.txt | 6 - .../v1beta3/testing/constraints-3.9.txt | 6 - .../v1beta3/tests/__init__.py | 16 - .../v1beta3/tests/unit/__init__.py | 16 - .../v1beta3/tests/unit/gapic/__init__.py | 16 - .../generativelanguage_v1beta3/__init__.py | 16 - packages/google-ai-generativelanguage/.flake8 | 2 +- .../CONTRIBUTING.rst | 2 +- .../google-ai-generativelanguage/MANIFEST.in | 2 +- .../google-ai-generativelanguage/README.rst | 27 +- .../google-ai-generativelanguage/docs/conf.py | 2 +- .../discuss_service.rst | 0 .../model_service.rst | 0 .../permission_service.rst | 0 .../generativelanguage_v1beta3/services.rst | 0 .../text_service.rst | 0 .../docs/generativelanguage_v1beta3/types.rst | 0 .../docs/index.rst | 11 + .../services/discuss_service/async_client.py | 1 + .../services/discuss_service/client.py | 1 + .../discuss_service/transports/rest.py | 1 + .../types/discuss_service.py | 4 + .../generativelanguage_v1beta2/types/model.py | 1 + .../types/safety.py | 5 + .../types/text_service.py | 2 + .../ai/generativelanguage_v1beta3/__init__.py | 155 + .../gapic_metadata.json | 0 .../gapic_version.py | 0 .../ai/generativelanguage_v1beta3}/py.typed | 0 
.../services/__init__.py | 0 .../services/discuss_service/__init__.py | 6 +- .../services/discuss_service/async_client.py | 167 +- .../services/discuss_service/client.py | 258 +- .../discuss_service/transports/__init__.py | 20 +- .../discuss_service/transports/base.py | 112 +- .../discuss_service/transports/grpc.py | 105 +- .../transports/grpc_asyncio.py | 103 +- .../discuss_service/transports/rest.py | 242 +- .../services/model_service/__init__.py | 6 +- .../services/model_service/async_client.py | 280 +- .../services/model_service/client.py | 389 ++- .../services/model_service/pagers.py | 85 +- .../model_service/transports/__init__.py | 20 +- .../services/model_service/transports/base.py | 181 +- .../services/model_service/transports/grpc.py | 181 +- .../model_service/transports/grpc_asyncio.py | 186 +- .../services/model_service/transports/rest.py | 614 ++-- .../services/permission_service/__init__.py | 6 +- .../permission_service/async_client.py | 270 +- .../services/permission_service/client.py | 381 ++- .../services/permission_service/pagers.py | 49 +- .../permission_service/transports/__init__.py | 20 +- .../permission_service/transports/base.py | 164 +- .../permission_service/transports/grpc.py | 169 +- .../transports/grpc_asyncio.py | 173 +- .../permission_service/transports/rest.py | 559 ++-- .../services/text_service/__init__.py | 6 +- .../services/text_service/async_client.py | 225 +- .../services/text_service/client.py | 324 +- .../text_service/transports/__init__.py | 20 +- .../services/text_service/transports/base.py | 141 +- .../services/text_service/transports/grpc.py | 133 +- .../text_service/transports/grpc_asyncio.py | 135 +- .../services/text_service/transports/rest.py | 410 ++- .../types/__init__.py | 121 +- .../types/citation.py | 11 +- .../types/discuss_service.py | 52 +- .../generativelanguage_v1beta3/types/model.py | 5 +- .../types/model_service.py | 27 +- .../types/permission.py | 6 +- .../types/permission_service.py | 24 +- .../types/safety.py | 32 +- .../types/text_service.py | 48 +- .../types/tuned_model.py | 59 +- .../google-ai-generativelanguage/noxfile.py | 2 +- ...cuss_service_count_message_tokens_async.py | 0 ...scuss_service_count_message_tokens_sync.py | 0 ..._discuss_service_generate_message_async.py | 0 ...d_discuss_service_generate_message_sync.py | 0 ..._model_service_create_tuned_model_async.py | 0 ...d_model_service_create_tuned_model_sync.py | 0 ..._model_service_delete_tuned_model_async.py | 0 ...d_model_service_delete_tuned_model_sync.py | 0 ...generated_model_service_get_model_async.py | 0 ..._generated_model_service_get_model_sync.py | 0 ...ted_model_service_get_tuned_model_async.py | 0 ...ated_model_service_get_tuned_model_sync.py | 0 ...nerated_model_service_list_models_async.py | 0 ...enerated_model_service_list_models_sync.py | 0 ...d_model_service_list_tuned_models_async.py | 0 ...ed_model_service_list_tuned_models_sync.py | 0 ..._model_service_update_tuned_model_async.py | 0 ...d_model_service_update_tuned_model_sync.py | 0 ...mission_service_create_permission_async.py | 0 ...rmission_service_create_permission_sync.py | 0 ...mission_service_delete_permission_async.py | 0 ...rmission_service_delete_permission_sync.py | 0 ...permission_service_get_permission_async.py | 0 ..._permission_service_get_permission_sync.py | 0 ...rmission_service_list_permissions_async.py | 0 ...ermission_service_list_permissions_sync.py | 0 ...ission_service_transfer_ownership_async.py | 0 ...mission_service_transfer_ownership_sync.py | 0 
...mission_service_update_permission_async.py | 0 ...rmission_service_update_permission_sync.py | 0 ...ted_text_service_batch_embed_text_async.py | 0 ...ated_text_service_batch_embed_text_sync.py | 0 ...ed_text_service_count_text_tokens_async.py | 0 ...ted_text_service_count_text_tokens_sync.py | 0 ...generated_text_service_embed_text_async.py | 0 ..._generated_text_service_embed_text_sync.py | 0 ...erated_text_service_generate_text_async.py | 0 ...nerated_text_service_generate_text_sync.py | 0 ..._google.ai.generativelanguage.v1beta3.json | 0 .../scripts/decrypt-secrets.sh | 2 +- ...xup_generativelanguage_v1beta3_keywords.py | 0 .../generativelanguage_v1beta3}/__init__.py | 0 .../test_discuss_service.py | 1281 +++++--- .../test_model_service.py | 2788 +++++++++++------ .../test_permission_service.py | 2461 +++++++++------ .../test_text_service.py | 1767 +++++++---- 220 files changed, 9075 insertions(+), 25789 deletions(-) delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/.coveragerc delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/.flake8 delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/MANIFEST.in delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/README.rst delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/_static/custom.css delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/conf.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/discuss_service.rst delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/model_service.rst delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/services.rst delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/text_service.rst delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/types.rst delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/index.rst delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_metadata.json delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_version.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/py.typed delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/async_client.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/client.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/base.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc.py delete mode 100644 
owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc_asyncio.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/rest.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/async_client.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/client.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/pagers.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/base.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc_asyncio.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/rest.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/async_client.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/client.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/base.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc_asyncio.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/rest.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/citation.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/discuss_service.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model_service.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/safety.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/text_service.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/mypy.ini delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/noxfile.py delete mode 100644 
owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_async.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_sync.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_async.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_sync.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_async.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_sync.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_async.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_sync.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_async.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_sync.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_async.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_sync.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta2.json delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/scripts/fixup_generativelanguage_v1beta2_keywords.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/setup.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.10.txt delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.11.txt delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.12.txt delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.7.txt delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.8.txt delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.9.txt delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/__init__.py delete mode 100644 
owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_discuss_service.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_model_service.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_text_service.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/.coveragerc delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/.flake8 delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/MANIFEST.in delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/README.rst delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/_static/custom.css delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/conf.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/index.rst delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/gapic_version.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/py.typed delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_version.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/py.typed delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/mypy.ini delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/noxfile.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/setup.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.10.txt delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.11.txt delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.12.txt delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.7.txt delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.8.txt delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.9.txt delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/__init__.py delete mode 100644 owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/__init__.py delete mode 100644 
owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/__init__.py rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/docs/generativelanguage_v1beta3/discuss_service.rst (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/docs/generativelanguage_v1beta3/model_service.rst (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/docs/generativelanguage_v1beta3/permission_service.rst (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/docs/generativelanguage_v1beta3/services.rst (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/docs/generativelanguage_v1beta3/text_service.rst (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/docs/generativelanguage_v1beta3/types.rst (100%) create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/__init__.py rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/gapic_metadata.json (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage => packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3}/gapic_version.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage => packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3}/py.typed (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2 => packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3}/services/__init__.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/discuss_service/__init__.py (92%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/discuss_service/async_client.py (83%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/discuss_service/client.py (82%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2 => packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3}/services/discuss_service/transports/__init__.py (67%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/base.py (65%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc.py (81%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc_asyncio.py (81%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/rest.py (73%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => 
packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/model_service/__init__.py (92%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/model_service/async_client.py (85%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/model_service/client.py (83%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/model_service/pagers.py (84%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2 => packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3}/services/model_service/transports/__init__.py (68%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/model_service/transports/base.py (62%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc.py (77%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc_asyncio.py (77%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/model_service/transports/rest.py (69%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/permission_service/__init__.py (91%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py (84%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/permission_service/client.py (82%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/permission_service/pagers.py (84%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/permission_service/transports/__init__.py (66%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/permission_service/transports/base.py (63%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc.py (77%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc_asyncio.py (77%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/permission_service/transports/rest.py (70%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2 => 
packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3}/services/text_service/__init__.py (92%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/text_service/async_client.py (84%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/text_service/client.py (82%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2 => packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3}/services/text_service/transports/__init__.py (68%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/text_service/transports/base.py (63%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc.py (79%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc_asyncio.py (79%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/services/text_service/transports/rest.py (70%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/types/__init__.py (56%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/types/citation.py (92%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/types/discuss_service.py (91%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/types/model.py (98%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/types/model_service.py (95%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/types/permission.py (98%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/types/permission_service.py (93%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/types/safety.py (94%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/types/text_service.py (93%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/google/ai/generativelanguage_v1beta3/types/tuned_model.py (92%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => 
packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_sync.py (100%) rename 
{owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => 
packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_async.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_sync.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta3.json (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/scripts/fixup_generativelanguage_v1beta3_keywords.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services => packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3}/__init__.py (100%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/tests/unit/gapic/generativelanguage_v1beta3/test_discuss_service.py (70%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/tests/unit/gapic/generativelanguage_v1beta3/test_model_service.py (68%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/tests/unit/gapic/generativelanguage_v1beta3/test_permission_service.py (69%) rename {owl-bot-staging/google-ai-generativelanguage/v1beta3 => packages/google-ai-generativelanguage}/tests/unit/gapic/generativelanguage_v1beta3/test_text_service.py (70%) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/.coveragerc b/owl-bot-staging/google-ai-generativelanguage/v1beta2/.coveragerc deleted file mode 100644 index fd060ae956b5..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/.coveragerc +++ /dev/null @@ -1,13 +0,0 @@ -[run] -branch = True - -[report] -show_missing = True -omit = - google/ai/generativelanguage/__init__.py - google/ai/generativelanguage/gapic_version.py -exclude_lines = - # Re-enable the standard pragma - pragma: NO COVER - # Ignore debug-only repr - def __repr__ diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/.flake8 b/owl-bot-staging/google-ai-generativelanguage/v1beta2/.flake8 deleted file mode 100644 index 29227d4cf419..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/.flake8 +++ /dev/null @@ -1,33 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Generated by synthtool. 
DO NOT EDIT! -[flake8] -ignore = E203, E266, E501, W503 -exclude = - # Exclude generated code. - **/proto/** - **/gapic/** - **/services/** - **/types/** - *_pb2.py - - # Standard linting exemptions. - **/.nox/** - __pycache__, - .git, - *.pyc, - conf.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/MANIFEST.in b/owl-bot-staging/google-ai-generativelanguage/v1beta2/MANIFEST.in deleted file mode 100644 index 27e3433a8451..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/MANIFEST.in +++ /dev/null @@ -1,2 +0,0 @@ -recursive-include google/ai/generativelanguage *.py -recursive-include google/ai/generativelanguage_v1beta2 *.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/README.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/README.rst deleted file mode 100644 index 099f73894711..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/README.rst +++ /dev/null @@ -1,49 +0,0 @@ -Python Client for Google Ai Generativelanguage API -================================================= - -Quick Start ------------ - -In order to use this library, you first need to go through the following steps: - -1. `Select or create a Cloud Platform project.`_ -2. `Enable billing for your project.`_ -3. Enable the Google Ai Generativelanguage API. -4. `Setup Authentication.`_ - -.. _Select or create a Cloud Platform project.: https://console.cloud.google.com/project -.. _Enable billing for your project.: https://cloud.google.com/billing/docs/how-to/modify-project#enable_billing_for_a_project -.. _Setup Authentication.: https://googleapis.dev/python/google-api-core/latest/auth.html - -Installation -~~~~~~~~~~~~ - -Install this library in a `virtualenv`_ using pip. `virtualenv`_ is a tool to -create isolated Python environments. The basic problem it addresses is one of -dependencies and versions, and indirectly permissions. - -With `virtualenv`_, it's possible to install this library without needing system -install permissions, and without clashing with the installed system -dependencies. - -.. _`virtualenv`: https://virtualenv.pypa.io/en/latest/ - - -Mac/Linux -^^^^^^^^^ - -.. code-block:: console - - python3 -m venv <your-env> - source <your-env>/bin/activate - <your-env>/bin/pip install /path/to/library - - -Windows -^^^^^^^ - -.. code-block:: console - - python3 -m venv <your-env> - <your-env>\Scripts\activate - <your-env>\Scripts\pip.exe install \path\to\library diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/_static/custom.css b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/_static/custom.css deleted file mode 100644 index 06423be0b592..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/_static/custom.css +++ /dev/null @@ -1,3 +0,0 @@ -dl.field-list > dt { - min-width: 100px -} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/conf.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/conf.py deleted file mode 100644 index 0f3f4903ff54..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/conf.py +++ /dev/null @@ -1,376 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# -# google-ai-generativelanguage documentation build configuration file -# -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -import sys -import os -import shlex - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath("..")) - -__version__ = "0.1.0" - -# -- General configuration ------------------------------------------------ - -# If your documentation needs a minimal Sphinx version, state it here. -needs_sphinx = "4.0.1" - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - "sphinx.ext.autodoc", - "sphinx.ext.autosummary", - "sphinx.ext.intersphinx", - "sphinx.ext.coverage", - "sphinx.ext.napoleon", - "sphinx.ext.todo", - "sphinx.ext.viewcode", -] - -# autodoc/autosummary flags -autoclass_content = "both" -autodoc_default_flags = ["members"] -autosummary_generate = True - - -# Add any paths that contain templates here, relative to this directory. -templates_path = ["_templates"] - -# Allow markdown includes (so releases.md can include CHANGELOG.md) -# http://www.sphinx-doc.org/en/master/markdown.html -source_parsers = {".md": "recommonmark.parser.CommonMarkParser"} - -# The suffix(es) of source filenames. -# You can specify multiple suffixes as a list of strings: -source_suffix = [".rst", ".md"] - -# The encoding of source files. -# source_encoding = 'utf-8-sig' - -# The root toctree document. -root_doc = "index" - -# General information about the project. -project = u"google-ai-generativelanguage" -copyright = u"2023, Google, LLC" -author = u"Google APIs" # TODO: autogenerate this bit - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The full version, including alpha/beta/rc tags. -release = __version__ -# The short X.Y version. -version = ".".join(release.split(".")[0:2]) - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = 'en' - -# There are two options for replacing |today|: either, you set today to some -# non-false value, then it is used: -# today = '' -# Else, today_fmt is used as the format for a strftime call. -# today_fmt = '%B %d, %Y' - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files.
-exclude_patterns = ["_build"] - -# The reST default role (used for this markup: `text`) to use for all -# documents. -# default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -# add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -# add_module_names = True - -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -# show_authors = False - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = "sphinx" - -# A list of ignored prefixes for module index sorting. -# modindex_common_prefix = [] - -# If true, keep warnings as "system message" paragraphs in the built documents. -# keep_warnings = False - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = True - - -# -- Options for HTML output ---------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -html_theme = "alabaster" - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -html_theme_options = { - "description": "Google Ai Client Libraries for Python", - "github_user": "googleapis", - "github_repo": "google-cloud-python", - "github_banner": True, - "font_family": "'Roboto', Georgia, sans", - "head_font_family": "'Roboto', Georgia, serif", - "code_font_family": "'Roboto Mono', 'Consolas', monospace", -} - -# Add any paths that contain custom themes here, relative to this directory. -# html_theme_path = [] - -# The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". -# html_title = None - -# A shorter title for the navigation bar. Default is the same as html_title. -# html_short_title = None - -# The name of an image file (relative to this directory) to place at the top -# of the sidebar. -# html_logo = None - -# The name of an image file (within the static path) to use as favicon of the -# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 -# pixels large. -# html_favicon = None - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ["_static"] - -# Add any extra paths that contain custom files (such as robots.txt or -# .htaccess) here, relative to this directory. These files are copied -# directly to the root of the documentation. -# html_extra_path = [] - -# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, -# using the given strftime format. -# html_last_updated_fmt = '%b %d, %Y' - -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -# html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. -# html_sidebars = {} - -# Additional templates that should be rendered to pages, maps page names to -# template names. -# html_additional_pages = {} - -# If false, no module index is generated. -# html_domain_indices = True - -# If false, no index is generated. -# html_use_index = True - -# If true, the index is split into individual pages for each letter. 
-# html_split_index = False - -# If true, links to the reST sources are added to the pages. -# html_show_sourcelink = True - -# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -# html_show_sphinx = True - -# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -# html_show_copyright = True - -# If true, an OpenSearch description file will be output, and all pages will -# contain a <link> tag referring to it. The value of this option must be the -# base URL from which the finished HTML is served. -# html_use_opensearch = '' - -# This is the file name suffix for HTML files (e.g. ".xhtml"). -# html_file_suffix = None - -# Language to be used for generating the HTML full-text search index. -# Sphinx supports the following languages: -# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' -# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' -# html_search_language = 'en' - -# A dictionary with options for the search language support, empty by default. -# Now only 'ja' uses this config value -# html_search_options = {'type': 'default'} - -# The name of a javascript file (relative to the configuration directory) that -# implements a search results scorer. If empty, the default will be used. -# html_search_scorer = 'scorer.js' - -# Output file base name for HTML help builder. -htmlhelp_basename = "google-ai-generativelanguage-doc" - -# -- Options for warnings ------------------------------------------------------ - - -suppress_warnings = [ - # Temporarily suppress this to avoid "more than one target found for - # cross-reference" warning, which are intractable for us to avoid while in - # a mono-repo. - # See https://github.com/sphinx-doc/sphinx/blob - # /2a65ffeef5c107c19084fabdd706cdff3f52d93c/sphinx/domains/python.py#L843 - "ref.python" -] - -# -- Options for LaTeX output --------------------------------------------- - -latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). - # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. - # 'preamble': '', - # Latex figure (float) alignment - # 'figure_align': 'htbp', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - ( - root_doc, - "google-ai-generativelanguage.tex", - u"google-ai-generativelanguage Documentation", - author, - "manual", - ) -] - -# The name of an image file (relative to this directory) to place at the top of -# the title page. -# latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -# latex_use_parts = False - -# If true, show page references after internal links. -# latex_show_pagerefs = False - -# If true, show URL addresses after external links. -# latex_show_urls = False - -# Documents to append as an appendix to all manuals. -# latex_appendices = [] - -# If false, no module index is generated. -# latex_domain_indices = True - - -# -- Options for manual page output --------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - ( - root_doc, - "google-ai-generativelanguage", - u"Google Ai Generativelanguage Documentation", - [author], - 1, - ) -] - -# If true, show URL addresses after external links.
-# man_show_urls = False - - -# -- Options for Texinfo output ------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - ( - root_doc, - "google-ai-generativelanguage", - u"google-ai-generativelanguage Documentation", - author, - "google-ai-generativelanguage", - "GAPIC library for Google Ai Generativelanguage API", - "APIs", - ) -] - -# Documents to append as an appendix to all manuals. -# texinfo_appendices = [] - -# If false, no module index is generated. -# texinfo_domain_indices = True - -# How to display URL addresses: 'footnote', 'no', or 'inline'. -# texinfo_show_urls = 'footnote' - -# If true, do not generate a @detailmenu in the "Top" node's menu. -# texinfo_no_detailmenu = False - - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = { - "python": ("http://python.readthedocs.org/en/latest/", None), - "gax": ("https://gax-python.readthedocs.org/en/latest/", None), - "google-auth": ("https://google-auth.readthedocs.io/en/stable", None), - "google-gax": ("https://gax-python.readthedocs.io/en/latest/", None), - "google.api_core": ("https://googleapis.dev/python/google-api-core/latest/", None), - "grpc": ("https://grpc.io/grpc/python/", None), - "requests": ("http://requests.kennethreitz.org/en/stable/", None), - "proto": ("https://proto-plus-python.readthedocs.io/en/stable", None), - "protobuf": ("https://googleapis.dev/python/protobuf/latest/", None), -} - - -# Napoleon settings -napoleon_google_docstring = True -napoleon_numpy_docstring = True -napoleon_include_private_with_doc = False -napoleon_include_special_with_doc = True -napoleon_use_admonition_for_examples = False -napoleon_use_admonition_for_notes = False -napoleon_use_admonition_for_references = False -napoleon_use_ivar = False -napoleon_use_param = True -napoleon_use_rtype = True diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/discuss_service.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/discuss_service.rst deleted file mode 100644 index be72af9f8e59..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/discuss_service.rst +++ /dev/null @@ -1,6 +0,0 @@ -DiscussService --------------------------------- - -.. automodule:: google.ai.generativelanguage_v1beta2.services.discuss_service - :members: - :inherited-members: diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/model_service.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/model_service.rst deleted file mode 100644 index 7edf8f7f17c5..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/model_service.rst +++ /dev/null @@ -1,10 +0,0 @@ -ModelService ------------------------------- - -.. automodule:: google.ai.generativelanguage_v1beta2.services.model_service - :members: - :inherited-members: - -.. 
automodule:: google.ai.generativelanguage_v1beta2.services.model_service.pagers - :members: - :inherited-members: diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/services.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/services.rst deleted file mode 100644 index e9e01c10ac08..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/services.rst +++ /dev/null @@ -1,8 +0,0 @@ -Services for Google Ai Generativelanguage v1beta2 API -===================================================== -.. toctree:: - :maxdepth: 2 - - discuss_service - model_service - text_service diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/text_service.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/text_service.rst deleted file mode 100644 index f30551e0f177..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/text_service.rst +++ /dev/null @@ -1,6 +0,0 @@ -TextService ------------------------------ - -.. automodule:: google.ai.generativelanguage_v1beta2.services.text_service - :members: - :inherited-members: diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/types.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/types.rst deleted file mode 100644 index 81b702c1c9e1..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/generativelanguage_v1beta2/types.rst +++ /dev/null @@ -1,6 +0,0 @@ -Types for Google Ai Generativelanguage v1beta2 API -================================================== - -.. automodule:: google.ai.generativelanguage_v1beta2.types - :members: - :show-inheritance: diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/index.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/index.rst deleted file mode 100644 index c5b70436ea18..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/docs/index.rst +++ /dev/null @@ -1,7 +0,0 @@ -API Reference -------------- -.. toctree:: - :maxdepth: 2 - - generativelanguage_v1beta2/services - generativelanguage_v1beta2/types diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/__init__.py deleted file mode 100644 index 16becd33efb7..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/__init__.py +++ /dev/null @@ -1,85 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -from google.ai.generativelanguage import gapic_version as package_version - -__version__ = package_version.__version__ - - -from google.ai.generativelanguage_v1beta2.services.discuss_service.client import DiscussServiceClient -from google.ai.generativelanguage_v1beta2.services.discuss_service.async_client import DiscussServiceAsyncClient -from google.ai.generativelanguage_v1beta2.services.model_service.client import ModelServiceClient -from google.ai.generativelanguage_v1beta2.services.model_service.async_client import ModelServiceAsyncClient -from google.ai.generativelanguage_v1beta2.services.text_service.client import TextServiceClient -from google.ai.generativelanguage_v1beta2.services.text_service.async_client import TextServiceAsyncClient - -from google.ai.generativelanguage_v1beta2.types.citation import CitationMetadata -from google.ai.generativelanguage_v1beta2.types.citation import CitationSource -from google.ai.generativelanguage_v1beta2.types.discuss_service import CountMessageTokensRequest -from google.ai.generativelanguage_v1beta2.types.discuss_service import CountMessageTokensResponse -from google.ai.generativelanguage_v1beta2.types.discuss_service import Example -from google.ai.generativelanguage_v1beta2.types.discuss_service import GenerateMessageRequest -from google.ai.generativelanguage_v1beta2.types.discuss_service import GenerateMessageResponse -from google.ai.generativelanguage_v1beta2.types.discuss_service import Message -from google.ai.generativelanguage_v1beta2.types.discuss_service import MessagePrompt -from google.ai.generativelanguage_v1beta2.types.model import Model -from google.ai.generativelanguage_v1beta2.types.model_service import GetModelRequest -from google.ai.generativelanguage_v1beta2.types.model_service import ListModelsRequest -from google.ai.generativelanguage_v1beta2.types.model_service import ListModelsResponse -from google.ai.generativelanguage_v1beta2.types.safety import ContentFilter -from google.ai.generativelanguage_v1beta2.types.safety import SafetyFeedback -from google.ai.generativelanguage_v1beta2.types.safety import SafetyRating -from google.ai.generativelanguage_v1beta2.types.safety import SafetySetting -from google.ai.generativelanguage_v1beta2.types.safety import HarmCategory -from google.ai.generativelanguage_v1beta2.types.text_service import Embedding -from google.ai.generativelanguage_v1beta2.types.text_service import EmbedTextRequest -from google.ai.generativelanguage_v1beta2.types.text_service import EmbedTextResponse -from google.ai.generativelanguage_v1beta2.types.text_service import GenerateTextRequest -from google.ai.generativelanguage_v1beta2.types.text_service import GenerateTextResponse -from google.ai.generativelanguage_v1beta2.types.text_service import TextCompletion -from google.ai.generativelanguage_v1beta2.types.text_service import TextPrompt - -__all__ = ('DiscussServiceClient', - 'DiscussServiceAsyncClient', - 'ModelServiceClient', - 'ModelServiceAsyncClient', - 'TextServiceClient', - 'TextServiceAsyncClient', - 'CitationMetadata', - 'CitationSource', - 'CountMessageTokensRequest', - 'CountMessageTokensResponse', - 'Example', - 'GenerateMessageRequest', - 'GenerateMessageResponse', - 'Message', - 'MessagePrompt', - 'Model', - 'GetModelRequest', - 'ListModelsRequest', - 'ListModelsResponse', - 'ContentFilter', - 'SafetyFeedback', - 'SafetyRating', - 'SafetySetting', - 'HarmCategory', - 'Embedding', - 'EmbedTextRequest', - 'EmbedTextResponse', - 'GenerateTextRequest', - 'GenerateTextResponse', - 'TextCompletion', - 
'TextPrompt', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/__init__.py deleted file mode 100644 index 06c40d65931f..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/__init__.py +++ /dev/null @@ -1,86 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from google.ai.generativelanguage_v1beta2 import gapic_version as package_version - -__version__ = package_version.__version__ - - -from .services.discuss_service import DiscussServiceClient -from .services.discuss_service import DiscussServiceAsyncClient -from .services.model_service import ModelServiceClient -from .services.model_service import ModelServiceAsyncClient -from .services.text_service import TextServiceClient -from .services.text_service import TextServiceAsyncClient - -from .types.citation import CitationMetadata -from .types.citation import CitationSource -from .types.discuss_service import CountMessageTokensRequest -from .types.discuss_service import CountMessageTokensResponse -from .types.discuss_service import Example -from .types.discuss_service import GenerateMessageRequest -from .types.discuss_service import GenerateMessageResponse -from .types.discuss_service import Message -from .types.discuss_service import MessagePrompt -from .types.model import Model -from .types.model_service import GetModelRequest -from .types.model_service import ListModelsRequest -from .types.model_service import ListModelsResponse -from .types.safety import ContentFilter -from .types.safety import SafetyFeedback -from .types.safety import SafetyRating -from .types.safety import SafetySetting -from .types.safety import HarmCategory -from .types.text_service import Embedding -from .types.text_service import EmbedTextRequest -from .types.text_service import EmbedTextResponse -from .types.text_service import GenerateTextRequest -from .types.text_service import GenerateTextResponse -from .types.text_service import TextCompletion -from .types.text_service import TextPrompt - -__all__ = ( - 'DiscussServiceAsyncClient', - 'ModelServiceAsyncClient', - 'TextServiceAsyncClient', -'CitationMetadata', -'CitationSource', -'ContentFilter', -'CountMessageTokensRequest', -'CountMessageTokensResponse', -'DiscussServiceClient', -'EmbedTextRequest', -'EmbedTextResponse', -'Embedding', -'Example', -'GenerateMessageRequest', -'GenerateMessageResponse', -'GenerateTextRequest', -'GenerateTextResponse', -'GetModelRequest', -'HarmCategory', -'ListModelsRequest', -'ListModelsResponse', -'Message', -'MessagePrompt', -'Model', -'ModelServiceClient', -'SafetyFeedback', -'SafetyRating', -'SafetySetting', -'TextCompletion', -'TextPrompt', -'TextServiceClient', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_metadata.json 
b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_metadata.json deleted file mode 100644 index e4a6a33d7d90..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_metadata.json +++ /dev/null @@ -1,156 +0,0 @@ - { - "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", - "language": "python", - "libraryPackage": "google.ai.generativelanguage_v1beta2", - "protoPackage": "google.ai.generativelanguage.v1beta2", - "schema": "1.0", - "services": { - "DiscussService": { - "clients": { - "grpc": { - "libraryClient": "DiscussServiceClient", - "rpcs": { - "CountMessageTokens": { - "methods": [ - "count_message_tokens" - ] - }, - "GenerateMessage": { - "methods": [ - "generate_message" - ] - } - } - }, - "grpc-async": { - "libraryClient": "DiscussServiceAsyncClient", - "rpcs": { - "CountMessageTokens": { - "methods": [ - "count_message_tokens" - ] - }, - "GenerateMessage": { - "methods": [ - "generate_message" - ] - } - } - }, - "rest": { - "libraryClient": "DiscussServiceClient", - "rpcs": { - "CountMessageTokens": { - "methods": [ - "count_message_tokens" - ] - }, - "GenerateMessage": { - "methods": [ - "generate_message" - ] - } - } - } - } - }, - "ModelService": { - "clients": { - "grpc": { - "libraryClient": "ModelServiceClient", - "rpcs": { - "GetModel": { - "methods": [ - "get_model" - ] - }, - "ListModels": { - "methods": [ - "list_models" - ] - } - } - }, - "grpc-async": { - "libraryClient": "ModelServiceAsyncClient", - "rpcs": { - "GetModel": { - "methods": [ - "get_model" - ] - }, - "ListModels": { - "methods": [ - "list_models" - ] - } - } - }, - "rest": { - "libraryClient": "ModelServiceClient", - "rpcs": { - "GetModel": { - "methods": [ - "get_model" - ] - }, - "ListModels": { - "methods": [ - "list_models" - ] - } - } - } - } - }, - "TextService": { - "clients": { - "grpc": { - "libraryClient": "TextServiceClient", - "rpcs": { - "EmbedText": { - "methods": [ - "embed_text" - ] - }, - "GenerateText": { - "methods": [ - "generate_text" - ] - } - } - }, - "grpc-async": { - "libraryClient": "TextServiceAsyncClient", - "rpcs": { - "EmbedText": { - "methods": [ - "embed_text" - ] - }, - "GenerateText": { - "methods": [ - "generate_text" - ] - } - } - }, - "rest": { - "libraryClient": "TextServiceClient", - "rpcs": { - "EmbedText": { - "methods": [ - "embed_text" - ] - }, - "GenerateText": { - "methods": [ - "generate_text" - ] - } - } - } - } - } - } -} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_version.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_version.py deleted file mode 100644 index 360a0d13ebdd..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/gapic_version.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# -__version__ = "0.0.0" # {x-release-please-version} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/py.typed b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/py.typed deleted file mode 100644 index 38773eee6363..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/py.typed +++ /dev/null @@ -1,2 +0,0 @@ -# Marker file for PEP 561. -# The google-ai-generativelanguage package uses inline types. diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/__init__.py deleted file mode 100644 index c5c6e8208269..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from .client import DiscussServiceClient -from .async_client import DiscussServiceAsyncClient - -__all__ = ( - 'DiscussServiceClient', - 'DiscussServiceAsyncClient', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/async_client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/async_client.py deleted file mode 100644 index b6fbb11900d2..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/async_client.py +++ /dev/null @@ -1,508 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -from collections import OrderedDict -import functools -import re -from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union - -from google.ai.generativelanguage_v1beta2 import gapic_version as package_version - -from google.api_core.client_options import ClientOptions -from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] -except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object] # type: ignore - -from google.ai.generativelanguage_v1beta2.types import discuss_service -from google.ai.generativelanguage_v1beta2.types import safety -from .transports.base import DiscussServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import DiscussServiceGrpcAsyncIOTransport -from .client import DiscussServiceClient - - -class DiscussServiceAsyncClient: - """An API for using Generative Language Models (GLMs) in dialog - applications. - Also known as large language models (LLMs), this API provides - models that are trained for multi-turn dialog. - """ - - _client: DiscussServiceClient - - DEFAULT_ENDPOINT = DiscussServiceClient.DEFAULT_ENDPOINT - DEFAULT_MTLS_ENDPOINT = DiscussServiceClient.DEFAULT_MTLS_ENDPOINT - - model_path = staticmethod(DiscussServiceClient.model_path) - parse_model_path = staticmethod(DiscussServiceClient.parse_model_path) - common_billing_account_path = staticmethod(DiscussServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(DiscussServiceClient.parse_common_billing_account_path) - common_folder_path = staticmethod(DiscussServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(DiscussServiceClient.parse_common_folder_path) - common_organization_path = staticmethod(DiscussServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(DiscussServiceClient.parse_common_organization_path) - common_project_path = staticmethod(DiscussServiceClient.common_project_path) - parse_common_project_path = staticmethod(DiscussServiceClient.parse_common_project_path) - common_location_path = staticmethod(DiscussServiceClient.common_location_path) - parse_common_location_path = staticmethod(DiscussServiceClient.parse_common_location_path) - - @classmethod - def from_service_account_info(cls, info: dict, *args, **kwargs): - """Creates an instance of this client using the provided credentials - info. - - Args: - info (dict): The service account private key info. - args: Additional arguments to pass to the constructor. - kwargs: Additional arguments to pass to the constructor. - - Returns: - DiscussServiceAsyncClient: The constructed client. - """ - return DiscussServiceClient.from_service_account_info.__func__(DiscussServiceAsyncClient, info, *args, **kwargs) # type: ignore - - @classmethod - def from_service_account_file(cls, filename: str, *args, **kwargs): - """Creates an instance of this client using the provided credentials - file. - - Args: - filename (str): The path to the service account private key json - file. - args: Additional arguments to pass to the constructor. - kwargs: Additional arguments to pass to the constructor. - - Returns: - DiscussServiceAsyncClient: The constructed client. 
- """ - return DiscussServiceClient.from_service_account_file.__func__(DiscussServiceAsyncClient, filename, *args, **kwargs) # type: ignore - - from_service_account_json = from_service_account_file - - @classmethod - def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None): - """Return the API endpoint and client cert source for mutual TLS. - - The client cert source is determined in the following order: - (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the - client cert source is None. - (2) if `client_options.client_cert_source` is provided, use the provided one; if the - default client cert source exists, use the default one; otherwise the client cert - source is None. - - The API endpoint is determined in the following order: - (1) if `client_options.api_endpoint` if provided, use the provided one. - (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the - default mTLS endpoint; if the environment variable is "never", use the default API - endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise - use the default API endpoint. - - More details can be found at https://google.aip.dev/auth/4114. - - Args: - client_options (google.api_core.client_options.ClientOptions): Custom options for the - client. Only the `api_endpoint` and `client_cert_source` properties may be used - in this method. - - Returns: - Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the - client cert source to use. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If any errors happen. - """ - return DiscussServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore - - @property - def transport(self) -> DiscussServiceTransport: - """Returns the transport used by the client instance. - - Returns: - DiscussServiceTransport: The transport used by the client instance. - """ - return self._client.transport - - get_transport_class = functools.partial(type(DiscussServiceClient).get_transport_class, type(DiscussServiceClient)) - - def __init__(self, *, - credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, DiscussServiceTransport] = "grpc_asyncio", - client_options: Optional[ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiates the discuss service client. - - Args: - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - transport (Union[str, ~.DiscussServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. 
- (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - """ - self._client = DiscussServiceClient( - credentials=credentials, - transport=transport, - client_options=client_options, - client_info=client_info, - - ) - - async def generate_message(self, - request: Optional[Union[discuss_service.GenerateMessageRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[discuss_service.MessagePrompt] = None, - temperature: Optional[float] = None, - candidate_count: Optional[int] = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> discuss_service.GenerateMessageResponse: - r"""Generates a response from the model given an input - ``MessagePrompt``. - - .. code-block:: python - - # This snippet has been automatically generated and should be regarded as a - # code template only. - # It will require modifications to work: - # - It may require correct/in-range values for request initialization. - # - It may require specifying regional endpoints when creating the service - # client as shown in: - # https://googleapis.dev/python/google-api-core/latest/client_options.html - from google.ai import generativelanguage_v1beta2 - - async def sample_generate_message(): - # Create a client - client = generativelanguage_v1beta2.DiscussServiceAsyncClient() - - # Initialize request argument(s) - prompt = generativelanguage_v1beta2.MessagePrompt() - prompt.messages.content = "content_value" - - request = generativelanguage_v1beta2.GenerateMessageRequest( - model="model_value", - prompt=prompt, - ) - - # Make the request - response = await client.generate_message(request=request) - - # Handle the response - print(response) - - Args: - request (Optional[Union[google.ai.generativelanguage_v1beta2.types.GenerateMessageRequest, dict]]): - The request object. Request to generate a message - response from the model. - model (:class:`str`): - Required. The name of the model to use. - - Format: ``name=models/{model}``. - - This corresponds to the ``model`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - prompt (:class:`google.ai.generativelanguage_v1beta2.types.MessagePrompt`): - Required. The structured textual - input given to the model as a prompt. - Given a - prompt, the model will return what it - predicts is the next message in the - discussion. - - This corresponds to the ``prompt`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - temperature (:class:`float`): - Optional. Controls the randomness of the output. - - Values can range over ``[0.0,1.0]``, inclusive. A value - closer to ``1.0`` will produce responses that are more - varied, while a value closer to ``0.0`` will typically - result in less surprising responses from the model. - - This corresponds to the ``temperature`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - candidate_count (:class:`int`): - Optional.
The number of generated response messages to - return. - - This value must be between ``[1, 8]``, inclusive. If - unset, this will default to ``1``. - - This corresponds to the ``candidate_count`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - top_p (:class:`float`): - Optional. The maximum cumulative probability of tokens - to consider when sampling. - - The model uses combined Top-k and nucleus sampling. - - Nucleus sampling considers the smallest set of tokens - whose probability sum is at least ``top_p``. - - This corresponds to the ``top_p`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - top_k (:class:`int`): - Optional. The maximum number of tokens to consider when - sampling. - - The model uses combined Top-k and nucleus sampling. - - Top-k sampling considers the set of ``top_k`` most - probable tokens. - - This corresponds to the ``top_k`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - google.ai.generativelanguage_v1beta2.types.GenerateMessageResponse: - The response from the model. - - This includes candidate messages and - conversation history in the form of - chronologically-ordered messages. - - """ - # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - has_flattened_params = any([model, prompt, temperature, candidate_count, top_p, top_k]) - if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") - - request = discuss_service.GenerateMessageRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - if model is not None: - request.model = model - if prompt is not None: - request.prompt = prompt - if temperature is not None: - request.temperature = temperature - if candidate_count is not None: - request.candidate_count = candidate_count - if top_p is not None: - request.top_p = top_p - if top_k is not None: - request.top_k = top_k - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.generate_message, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), - ) - - # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # Done; return the response. 
- return response - - async def count_message_tokens(self, - request: Optional[Union[discuss_service.CountMessageTokensRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[discuss_service.MessagePrompt] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> discuss_service.CountMessageTokensResponse: - r"""Runs a model's tokenizer on a string and returns the - token count. - - .. code-block:: python - - # This snippet has been automatically generated and should be regarded as a - # code template only. - # It will require modifications to work: - # - It may require correct/in-range values for request initialization. - # - It may require specifying regional endpoints when creating the service - # client as shown in: - # https://googleapis.dev/python/google-api-core/latest/client_options.html - from google.ai import generativelanguage_v1beta2 - - async def sample_count_message_tokens(): - # Create a client - client = generativelanguage_v1beta2.DiscussServiceAsyncClient() - - # Initialize request argument(s) - prompt = generativelanguage_v1beta2.MessagePrompt() - prompt.messages.content = "content_value" - - request = generativelanguage_v1beta2.CountMessageTokensRequest( - model="model_value", - prompt=prompt, - ) - - # Make the request - response = await client.count_message_tokens(request=request) - - # Handle the response - print(response) - - Args: - request (Optional[Union[google.ai.generativelanguage_v1beta2.types.CountMessageTokensRequest, dict]]): - The request object. Counts the number of tokens in the ``prompt`` sent to a - model. - - Models may tokenize text differently, so each model may - return a different ``token_count``. - model (:class:`str`): - Required. The model's resource name. This serves as an - ID for the Model to use. - - This name should match a model name returned by the - ``ListModels`` method. - - Format: ``models/{model}`` - - This corresponds to the ``model`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - prompt (:class:`google.ai.generativelanguage_v1beta2.types.MessagePrompt`): - Required. The prompt, whose token - count is to be returned. - - This corresponds to the ``prompt`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - google.ai.generativelanguage_v1beta2.types.CountMessageTokensResponse: - A response from CountMessageTokens. - - It returns the model's token_count for the prompt. - - """ - # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - has_flattened_params = any([model, prompt]) - if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") - - request = discuss_service.CountMessageTokensRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. 
- if model is not None: - request.model = model - if prompt is not None: - request.prompt = prompt - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.count_message_tokens, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), - ) - - # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # Done; return the response. - return response - - async def __aenter__(self) -> "DiscussServiceAsyncClient": - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.transport.close() - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - - -__all__ = ( - "DiscussServiceAsyncClient", -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/client.py deleted file mode 100644 index 6301bfd36a15..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/client.py +++ /dev/null @@ -1,712 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
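The async and sync clients in this patch share one calling convention: every RPC accepts either a fully-formed request object or the flattened fields, never both. A minimal usage sketch of that contract (illustrative only — it assumes valid credentials, a reachable endpoint, and a hypothetical model id; it is not part of the generated code being removed):

.. code-block:: python

    from google.ai import generativelanguage_v1beta2

    async def sketch_count_tokens():
        client = generativelanguage_v1beta2.DiscussServiceAsyncClient()

        prompt = generativelanguage_v1beta2.MessagePrompt()
        prompt.messages.content = "content_value"

        # Request-object form: build the request yourself.
        request = generativelanguage_v1beta2.CountMessageTokensRequest(
            model="models/model_value",  # hypothetical model resource name
            prompt=prompt,
        )
        response = await client.count_message_tokens(request=request)

        # Flattened form: the client copies the fields onto a new request.
        response = await client.count_message_tokens(
            model="models/model_value",
            prompt=prompt,
        )

        # Passing both a request object and flattened fields raises
        # ValueError, per the has_flattened_params check in the clients above.
        print(response.token_count)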
-#
-from collections import OrderedDict
-import os
-import re
-from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast
-
-from google.ai.generativelanguage_v1beta2 import gapic_version as package_version
-
-from google.api_core import client_options as client_options_lib
-from google.api_core import exceptions as core_exceptions
-from google.api_core import gapic_v1
-from google.api_core import retry as retries
-from google.auth import credentials as ga_credentials  # type: ignore
-from google.auth.transport import mtls  # type: ignore
-from google.auth.transport.grpc import SslCredentials  # type: ignore
-from google.auth.exceptions import MutualTLSChannelError  # type: ignore
-from google.oauth2 import service_account  # type: ignore
-
-try:
-    OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
-except AttributeError:  # pragma: NO COVER
-    OptionalRetry = Union[retries.Retry, object]  # type: ignore
-
-from google.ai.generativelanguage_v1beta2.types import discuss_service
-from google.ai.generativelanguage_v1beta2.types import safety
-from .transports.base import DiscussServiceTransport, DEFAULT_CLIENT_INFO
-from .transports.grpc import DiscussServiceGrpcTransport
-from .transports.grpc_asyncio import DiscussServiceGrpcAsyncIOTransport
-from .transports.rest import DiscussServiceRestTransport
-
-
-class DiscussServiceClientMeta(type):
-    """Metaclass for the DiscussService client.
-
-    This provides class-level methods for building and retrieving
-    support objects (e.g. transport) without polluting the client instance
-    objects.
-    """
-    _transport_registry = OrderedDict()  # type: Dict[str, Type[DiscussServiceTransport]]
-    _transport_registry["grpc"] = DiscussServiceGrpcTransport
-    _transport_registry["grpc_asyncio"] = DiscussServiceGrpcAsyncIOTransport
-    _transport_registry["rest"] = DiscussServiceRestTransport
-
-    def get_transport_class(cls,
-            label: Optional[str] = None,
-        ) -> Type[DiscussServiceTransport]:
-        """Returns an appropriate transport class.
-
-        Args:
-            label: The name of the desired transport. If none is
-                provided, then the first transport in the registry is used.
-
-        Returns:
-            The transport class to use.
-        """
-        # If a specific transport is requested, return that one.
-        if label:
-            return cls._transport_registry[label]
-
-        # No transport is requested; return the default (that is, the first one
-        # in the dictionary).
-        return next(iter(cls._transport_registry.values()))
-
-
-class DiscussServiceClient(metaclass=DiscussServiceClientMeta):
-    """An API for using Generative Language Models (GLMs) in dialog
-    applications.
-    Also known as large language models (LLMs), this API provides
-    models that are trained for multi-turn dialog.
-    """
-
-    @staticmethod
-    def _get_default_mtls_endpoint(api_endpoint):
-        """Converts api endpoint to mTLS endpoint.
-
-        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
-        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
-        Args:
-            api_endpoint (Optional[str]): the api endpoint to convert.
-        Returns:
-            str: converted mTLS api endpoint.
-        """
-        if not api_endpoint:
-            return api_endpoint
-
-        mtls_endpoint_re = re.compile(
-            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
- ) - - m = mtls_endpoint_re.match(api_endpoint) - name, mtls, sandbox, googledomain = m.groups() - if mtls or not googledomain: - return api_endpoint - - if sandbox: - return api_endpoint.replace( - "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" - ) - - return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - - DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" - DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore - DEFAULT_ENDPOINT - ) - - @classmethod - def from_service_account_info(cls, info: dict, *args, **kwargs): - """Creates an instance of this client using the provided credentials - info. - - Args: - info (dict): The service account private key info. - args: Additional arguments to pass to the constructor. - kwargs: Additional arguments to pass to the constructor. - - Returns: - DiscussServiceClient: The constructed client. - """ - credentials = service_account.Credentials.from_service_account_info(info) - kwargs["credentials"] = credentials - return cls(*args, **kwargs) - - @classmethod - def from_service_account_file(cls, filename: str, *args, **kwargs): - """Creates an instance of this client using the provided credentials - file. - - Args: - filename (str): The path to the service account private key json - file. - args: Additional arguments to pass to the constructor. - kwargs: Additional arguments to pass to the constructor. - - Returns: - DiscussServiceClient: The constructed client. - """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs["credentials"] = credentials - return cls(*args, **kwargs) - - from_service_account_json = from_service_account_file - - @property - def transport(self) -> DiscussServiceTransport: - """Returns the transport used by the client instance. - - Returns: - DiscussServiceTransport: The transport used by the client - instance. 
- """ - return self._transport - - @staticmethod - def model_path(model: str,) -> str: - """Returns a fully-qualified model string.""" - return "models/{model}".format(model=model, ) - - @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: - """Parses a model path into its component segments.""" - m = re.match(r"^models/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: - """Returns a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) - - @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: - """Parse a billing_account path into its component segments.""" - m = re.match(r"^billingAccounts/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_folder_path(folder: str, ) -> str: - """Returns a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) - - @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: - """Parse a folder path into its component segments.""" - m = re.match(r"^folders/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_organization_path(organization: str, ) -> str: - """Returns a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) - - @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: - """Parse a organization path into its component segments.""" - m = re.match(r"^organizations/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_project_path(project: str, ) -> str: - """Returns a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) - - @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: - """Parse a project path into its component segments.""" - m = re.match(r"^projects/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_location_path(project: str, location: str, ) -> str: - """Returns a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) - - @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: - """Parse a location path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) - return m.groupdict() if m else {} - - @classmethod - def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None): - """Return the API endpoint and client cert source for mutual TLS. - - The client cert source is determined in the following order: - (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the - client cert source is None. - (2) if `client_options.client_cert_source` is provided, use the provided one; if the - default client cert source exists, use the default one; otherwise the client cert - source is None. - - The API endpoint is determined in the following order: - (1) if `client_options.api_endpoint` if provided, use the provided one. - (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the - default mTLS endpoint; if the environment variable is "never", use the default API - endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise - use the default API endpoint. 
- - More details can be found at https://google.aip.dev/auth/4114. - - Args: - client_options (google.api_core.client_options.ClientOptions): Custom options for the - client. Only the `api_endpoint` and `client_cert_source` properties may be used - in this method. - - Returns: - Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the - client cert source to use. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If any errors happen. - """ - if client_options is None: - client_options = client_options_lib.ClientOptions() - use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") - use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_client_cert not in ("true", "false"): - raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`") - if use_mtls_endpoint not in ("auto", "never", "always"): - raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`") - - # Figure out the client cert source to use. - client_cert_source = None - if use_client_cert == "true": - if client_options.client_cert_source: - client_cert_source = client_options.client_cert_source - elif mtls.has_default_client_cert_source(): - client_cert_source = mtls.default_client_cert_source() - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source): - api_endpoint = cls.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = cls.DEFAULT_ENDPOINT - - return api_endpoint, client_cert_source - - def __init__(self, *, - credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, DiscussServiceTransport]] = None, - client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiates the discuss service client. - - Args: - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - transport (Union[str, DiscussServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the - client. It won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. 
- client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - """ - if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - client_options = cast(client_options_lib.ClientOptions, client_options) - - api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) - - api_key_value = getattr(client_options, "api_key", None) - if api_key_value and credentials: - raise ValueError("client_options.api_key and credentials are mutually exclusive") - - # Save or instantiate the transport. - # Ordinarily, we provide the transport, but allowing a custom transport - # instance provides an extensibility point for unusual situations. - if isinstance(transport, DiscussServiceTransport): - # transport is a DiscussServiceTransport instance. - if credentials or client_options.credentials_file or api_key_value: - raise ValueError("When providing a transport instance, " - "provide its credentials directly.") - if client_options.scopes: - raise ValueError( - "When providing a transport instance, provide its scopes " - "directly." - ) - self._transport = transport - else: - import google.auth._default # type: ignore - - if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): - credentials = google.auth._default.get_api_key_credentials(api_key_value) - - Transport = type(self).get_transport_class(transport) - self._transport = Transport( - credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - client_cert_source_for_mtls=client_cert_source_func, - quota_project_id=client_options.quota_project_id, - client_info=client_info, - always_use_jwt_access=True, - api_audience=client_options.api_audience, - ) - - def generate_message(self, - request: Optional[Union[discuss_service.GenerateMessageRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[discuss_service.MessagePrompt] = None, - temperature: Optional[float] = None, - candidate_count: Optional[int] = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> discuss_service.GenerateMessageResponse: - r"""Generates a response from the model given an input - ``MessagePrompt``. - - .. code-block:: python - - # This snippet has been automatically generated and should be regarded as a - # code template only. - # It will require modifications to work: - # - It may require correct/in-range values for request initialization. 
- # - It may require specifying regional endpoints when creating the service - # client as shown in: - # https://googleapis.dev/python/google-api-core/latest/client_options.html - from google.ai import generativelanguage_v1beta2 - - def sample_generate_message(): - # Create a client - client = generativelanguage_v1beta2.DiscussServiceClient() - - # Initialize request argument(s) - prompt = generativelanguage_v1beta2.MessagePrompt() - prompt.messages.content = "content_value" - - request = generativelanguage_v1beta2.GenerateMessageRequest( - model="model_value", - prompt=prompt, - ) - - # Make the request - response = client.generate_message(request=request) - - # Handle the response - print(response) - - Args: - request (Union[google.ai.generativelanguage_v1beta2.types.GenerateMessageRequest, dict]): - The request object. Request to generate a message - response from the model. - model (str): - Required. The name of the model to use. - - Format: ``name=models/{model}``. - - This corresponds to the ``model`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - prompt (google.ai.generativelanguage_v1beta2.types.MessagePrompt): - Required. The structured textual - input given to the model as a prompt. - Given a - prompt, the model will return what it - predicts is the next message in the - discussion. - - This corresponds to the ``prompt`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - temperature (float): - Optional. Controls the randomness of the output. - - Values can range over ``[0.0,1.0]``, inclusive. A value - closer to ``1.0`` will produce responses that are more - varied, while a value closer to ``0.0`` will typically - result in less surprising responses from the model. - - This corresponds to the ``temperature`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - candidate_count (int): - Optional. The number of generated response messages to - return. - - This value must be between ``[1, 8]``, inclusive. If - unset, this will default to ``1``. - - This corresponds to the ``candidate_count`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - top_p (float): - Optional. The maximum cumulative probability of tokens - to consider when sampling. - - The model uses combined Top-k and nucleus sampling. - - Nucleus sampling considers the smallest set of tokens - whose probability sum is at least ``top_p``. - - This corresponds to the ``top_p`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - top_k (int): - Optional. The maximum number of tokens to consider when - sampling. - - The model uses combined Top-k and nucleus sampling. - - Top-k sampling considers the set of ``top_k`` most - probable tokens. - - This corresponds to the ``top_k`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - google.ai.generativelanguage_v1beta2.types.GenerateMessageResponse: - The response from the model. - - This includes candidate messages and - conversation history in the form of - chronologically-ordered messages. - - """ - # Create or coerce a protobuf request object. 
- # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - has_flattened_params = any([model, prompt, temperature, candidate_count, top_p, top_k]) - if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') - - # Minor optimization to avoid making a copy if the user passes - # in a discuss_service.GenerateMessageRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, discuss_service.GenerateMessageRequest): - request = discuss_service.GenerateMessageRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. - if model is not None: - request.model = model - if prompt is not None: - request.prompt = prompt - if temperature is not None: - request.temperature = temperature - if candidate_count is not None: - request.candidate_count = candidate_count - if top_p is not None: - request.top_p = top_p - if top_k is not None: - request.top_k = top_k - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.generate_message] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), - ) - - # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # Done; return the response. - return response - - def count_message_tokens(self, - request: Optional[Union[discuss_service.CountMessageTokensRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[discuss_service.MessagePrompt] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> discuss_service.CountMessageTokensResponse: - r"""Runs a model's tokenizer on a string and returns the - token count. - - .. code-block:: python - - # This snippet has been automatically generated and should be regarded as a - # code template only. - # It will require modifications to work: - # - It may require correct/in-range values for request initialization. - # - It may require specifying regional endpoints when creating the service - # client as shown in: - # https://googleapis.dev/python/google-api-core/latest/client_options.html - from google.ai import generativelanguage_v1beta2 - - def sample_count_message_tokens(): - # Create a client - client = generativelanguage_v1beta2.DiscussServiceClient() - - # Initialize request argument(s) - prompt = generativelanguage_v1beta2.MessagePrompt() - prompt.messages.content = "content_value" - - request = generativelanguage_v1beta2.CountMessageTokensRequest( - model="model_value", - prompt=prompt, - ) - - # Make the request - response = client.count_message_tokens(request=request) - - # Handle the response - print(response) - - Args: - request (Union[google.ai.generativelanguage_v1beta2.types.CountMessageTokensRequest, dict]): - The request object. Counts the number of tokens in the ``prompt`` sent to a - model. - - Models may tokenize text differently, so each model may - return a different ``token_count``. - model (str): - Required. The model's resource name. This serves as an - ID for the Model to use. 
- - This name should match a model name returned by the - ``ListModels`` method. - - Format: ``models/{model}`` - - This corresponds to the ``model`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - prompt (google.ai.generativelanguage_v1beta2.types.MessagePrompt): - Required. The prompt, whose token - count is to be returned. - - This corresponds to the ``prompt`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - google.ai.generativelanguage_v1beta2.types.CountMessageTokensResponse: - A response from CountMessageTokens. - - It returns the model's token_count for the prompt. - - """ - # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - has_flattened_params = any([model, prompt]) - if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') - - # Minor optimization to avoid making a copy if the user passes - # in a discuss_service.CountMessageTokensRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, discuss_service.CountMessageTokensRequest): - request = discuss_service.CountMessageTokensRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. - if model is not None: - request.model = model - if prompt is not None: - request.prompt = prompt - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.count_message_tokens] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), - ) - - # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # Done; return the response. - return response - - def __enter__(self) -> "DiscussServiceClient": - return self - - def __exit__(self, type, value, traceback): - """Releases underlying transport's resources. - - .. warning:: - ONLY use as a context manager if the transport is NOT shared - with other clients! Exiting the with block will CLOSE the transport - and may cause errors in other clients! 
- """ - self.transport.close() - - - - - - - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - - -__all__ = ( - "DiscussServiceClient", -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/base.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/base.py deleted file mode 100644 index c7d8455ba342..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/base.py +++ /dev/null @@ -1,161 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import abc -from typing import Awaitable, Callable, Dict, Optional, Sequence, Union - -from google.ai.generativelanguage_v1beta2 import gapic_version as package_version - -import google.auth # type: ignore -import google.api_core -from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.ai.generativelanguage_v1beta2.types import discuss_service - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - - -class DiscussServiceTransport(abc.ABC): - """Abstract transport class for DiscussService.""" - - AUTH_SCOPES = ( - ) - - DEFAULT_HOST: str = 'generativelanguage.googleapis.com' - def __init__( - self, *, - host: str = DEFAULT_HOST, - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - **kwargs, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): A list of scopes. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. 
- Generally, you only need to set this if you're developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - """ - - scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} - - # Save the scopes. - self._scopes = scopes - - # If no credentials are provided, then determine the appropriate - # defaults. - if credentials and credentials_file: - raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") - - if credentials_file is not None: - credentials, _ = google.auth.load_credentials_from_file( - credentials_file, - **scopes_kwargs, - quota_project_id=quota_project_id - ) - elif credentials is None: - credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id) - # Don't apply audience if the credentials file passed from user. - if hasattr(credentials, "with_gdch_audience"): - credentials = credentials.with_gdch_audience(api_audience if api_audience else host) - - # If the credentials are service account credentials, then always try to use self signed JWT. - if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"): - credentials = credentials.with_always_use_jwt_access(True) - - # Save the credentials. - self._credentials = credentials - - # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' - self._host = host - - def _prep_wrapped_messages(self, client_info): - # Precompute the wrapped methods. - self._wrapped_methods = { - self.generate_message: gapic_v1.method.wrap_method( - self.generate_message, - default_timeout=None, - client_info=client_info, - ), - self.count_message_tokens: gapic_v1.method.wrap_method( - self.count_message_tokens, - default_timeout=None, - client_info=client_info, - ), - } - - def close(self): - """Closes resources associated with the transport. - - .. warning:: - Only call this method if the transport is NOT shared - with other clients - this may cause errors in other clients! - """ - raise NotImplementedError() - - @property - def generate_message(self) -> Callable[ - [discuss_service.GenerateMessageRequest], - Union[ - discuss_service.GenerateMessageResponse, - Awaitable[discuss_service.GenerateMessageResponse] - ]]: - raise NotImplementedError() - - @property - def count_message_tokens(self) -> Callable[ - [discuss_service.CountMessageTokensRequest], - Union[ - discuss_service.CountMessageTokensResponse, - Awaitable[discuss_service.CountMessageTokensResponse] - ]]: - raise NotImplementedError() - - @property - def kind(self) -> str: - raise NotImplementedError() - - -__all__ = ( - 'DiscussServiceTransport', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc.py deleted file mode 100644 index 7fc2d1e9779c..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc.py +++ /dev/null @@ -1,295 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import warnings -from typing import Callable, Dict, Optional, Sequence, Tuple, Union - -from google.api_core import grpc_helpers -from google.api_core import gapic_v1 -import google.auth # type: ignore -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore - -from google.ai.generativelanguage_v1beta2.types import discuss_service -from .base import DiscussServiceTransport, DEFAULT_CLIENT_INFO - - -class DiscussServiceGrpcTransport(DiscussServiceTransport): - """gRPC backend transport for DiscussService. - - An API for using Generative Language Models (GLMs) in dialog - applications. - Also known as large language models (LLMs), this API provides - models that are trained for multi-turn dialog. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. - """ - _stubs: Dict[str, Callable] - - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. 
A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. - client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): - A callback to provide client certificate bytes and private key bytes, - both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - self._grpc_channel = None - self._ssl_channel_credentials = ssl_channel_credentials - self._stubs: Dict[str, Callable] = {} - - if api_mtls_endpoint: - warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) - if client_cert_source: - warnings.warn("client_cert_source is deprecated", DeprecationWarning) - - if channel: - # Ignore credentials if a channel was passed. - credentials = False - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - self._ssl_channel_credentials = None - - else: - if api_mtls_endpoint: - host = api_mtls_endpoint - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - self._ssl_channel_credentials = SslCredentials().ssl_credentials - - else: - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - - # The base transport sets the host, credentials and scopes - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - client_info=client_info, - always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience, - ) - - if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( - self._host, - # use the credentials which are saved - credentials=self._credentials, - # Set ``credentials_file`` to ``None`` here as - # the credentials that we saved earlier should be used. - credentials_file=None, - scopes=self._scopes, - ssl_credentials=self._ssl_channel_credentials, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - # Wrap messages. 
This must be done after self._grpc_channel exists - self._prep_wrapped_messages(client_info) - - @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: - """Create and return a gRPC channel object. - Args: - host (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - grpc.Channel: A gRPC channel object. - - Raises: - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - - return grpc_helpers.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - quota_project_id=quota_project_id, - default_scopes=cls.AUTH_SCOPES, - scopes=scopes, - default_host=cls.DEFAULT_HOST, - **kwargs - ) - - @property - def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ - return self._grpc_channel - - @property - def generate_message(self) -> Callable[ - [discuss_service.GenerateMessageRequest], - discuss_service.GenerateMessageResponse]: - r"""Return a callable for the generate message method over gRPC. - - Generates a response from the model given an input - ``MessagePrompt``. - - Returns: - Callable[[~.GenerateMessageRequest], - ~.GenerateMessageResponse]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if 'generate_message' not in self._stubs: - self._stubs['generate_message'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta2.DiscussService/GenerateMessage', - request_serializer=discuss_service.GenerateMessageRequest.serialize, - response_deserializer=discuss_service.GenerateMessageResponse.deserialize, - ) - return self._stubs['generate_message'] - - @property - def count_message_tokens(self) -> Callable[ - [discuss_service.CountMessageTokensRequest], - discuss_service.CountMessageTokensResponse]: - r"""Return a callable for the count message tokens method over gRPC. - - Runs a model's tokenizer on a string and returns the - token count. - - Returns: - Callable[[~.CountMessageTokensRequest], - ~.CountMessageTokensResponse]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. 
- # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if 'count_message_tokens' not in self._stubs: - self._stubs['count_message_tokens'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta2.DiscussService/CountMessageTokens', - request_serializer=discuss_service.CountMessageTokensRequest.serialize, - response_deserializer=discuss_service.CountMessageTokensResponse.deserialize, - ) - return self._stubs['count_message_tokens'] - - def close(self): - self.grpc_channel.close() - - @property - def kind(self) -> str: - return "grpc" - - -__all__ = ( - 'DiscussServiceGrpcTransport', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc_asyncio.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc_asyncio.py deleted file mode 100644 index 97e6d426fc5c..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/grpc_asyncio.py +++ /dev/null @@ -1,294 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import warnings -from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union - -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers_async -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore -from grpc.experimental import aio # type: ignore - -from google.ai.generativelanguage_v1beta2.types import discuss_service -from .base import DiscussServiceTransport, DEFAULT_CLIENT_INFO -from .grpc import DiscussServiceGrpcTransport - - -class DiscussServiceGrpcAsyncIOTransport(DiscussServiceTransport): - """gRPC AsyncIO backend transport for DiscussService. - - An API for using Generative Language Models (GLMs) in dialog - applications. - Also known as large language models (LLMs), this API provides - models that are trained for multi-turn dialog. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. - """ - - _grpc_channel: aio.Channel - _stubs: Dict[str, Callable] = {} - - @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: - """Create and return a gRPC AsyncIO channel object. - Args: - host (Optional[str]): The host for the channel to use. 
- credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - aio.Channel: A gRPC AsyncIO channel object. - """ - - return grpc_helpers_async.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - quota_project_id=quota_project_id, - default_scopes=cls.AUTH_SCOPES, - scopes=scopes, - default_host=cls.DEFAULT_HOST, - **kwargs - ) - - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. 
- client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): - A callback to provide client certificate bytes and private key bytes, - both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - self._grpc_channel = None - self._ssl_channel_credentials = ssl_channel_credentials - self._stubs: Dict[str, Callable] = {} - - if api_mtls_endpoint: - warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) - if client_cert_source: - warnings.warn("client_cert_source is deprecated", DeprecationWarning) - - if channel: - # Ignore credentials if a channel was passed. - credentials = False - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - self._ssl_channel_credentials = None - else: - if api_mtls_endpoint: - host = api_mtls_endpoint - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - self._ssl_channel_credentials = SslCredentials().ssl_credentials - - else: - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - - # The base transport sets the host, credentials and scopes - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - client_info=client_info, - always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience, - ) - - if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( - self._host, - # use the credentials which are saved - credentials=self._credentials, - # Set ``credentials_file`` to ``None`` here as - # the credentials that we saved earlier should be used. - credentials_file=None, - scopes=self._scopes, - ssl_credentials=self._ssl_channel_credentials, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - # Wrap messages. This must be done after self._grpc_channel exists - self._prep_wrapped_messages(client_info) - - @property - def grpc_channel(self) -> aio.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. - """ - # Return the channel from cache. 
- return self._grpc_channel - - @property - def generate_message(self) -> Callable[ - [discuss_service.GenerateMessageRequest], - Awaitable[discuss_service.GenerateMessageResponse]]: - r"""Return a callable for the generate message method over gRPC. - - Generates a response from the model given an input - ``MessagePrompt``. - - Returns: - Callable[[~.GenerateMessageRequest], - Awaitable[~.GenerateMessageResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if 'generate_message' not in self._stubs: - self._stubs['generate_message'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta2.DiscussService/GenerateMessage', - request_serializer=discuss_service.GenerateMessageRequest.serialize, - response_deserializer=discuss_service.GenerateMessageResponse.deserialize, - ) - return self._stubs['generate_message'] - - @property - def count_message_tokens(self) -> Callable[ - [discuss_service.CountMessageTokensRequest], - Awaitable[discuss_service.CountMessageTokensResponse]]: - r"""Return a callable for the count message tokens method over gRPC. - - Runs a model's tokenizer on a string and returns the - token count. - - Returns: - Callable[[~.CountMessageTokensRequest], - Awaitable[~.CountMessageTokensResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if 'count_message_tokens' not in self._stubs: - self._stubs['count_message_tokens'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta2.DiscussService/CountMessageTokens', - request_serializer=discuss_service.CountMessageTokensRequest.serialize, - response_deserializer=discuss_service.CountMessageTokensResponse.deserialize, - ) - return self._stubs['count_message_tokens'] - - def close(self): - return self.grpc_channel.close() - - -__all__ = ( - 'DiscussServiceGrpcAsyncIOTransport', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/rest.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/rest.py deleted file mode 100644 index fd68266db64d..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/rest.py +++ /dev/null @@ -1,433 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
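The gRPC and gRPC-AsyncIO transports above create their RPC callables lazily: each property builds a ``unary_unary`` stub on first access and caches it in ``self._stubs``, so serializers are registered only for RPCs that are actually used. A stripped-down sketch of that caching pattern (illustrative only, not part of the generated code; it assumes an already-created channel):

.. code-block:: python

    import grpc

    from google.ai.generativelanguage_v1beta2.types import discuss_service


    class LazyStubSketch:
        """Caches one unary-unary stub per RPC, like the transports above."""

        def __init__(self, channel: grpc.Channel):
            self._grpc_channel = channel
            self._stubs = {}

        @property
        def generate_message(self):
            # Build the stub on first access; later accesses reuse the
            # cached callable instead of re-registering serializers.
            if "generate_message" not in self._stubs:
                self._stubs["generate_message"] = self._grpc_channel.unary_unary(
                    "/google.ai.generativelanguage.v1beta2.DiscussService/GenerateMessage",
                    request_serializer=discuss_service.GenerateMessageRequest.serialize,
                    response_deserializer=discuss_service.GenerateMessageResponse.deserialize,
                )
            return self._stubs["generate_message"]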
-# - -from google.auth.transport.requests import AuthorizedSession # type: ignore -import json # type: ignore -import grpc # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth import credentials as ga_credentials # type: ignore -from google.api_core import exceptions as core_exceptions -from google.api_core import retry as retries -from google.api_core import rest_helpers -from google.api_core import rest_streaming -from google.api_core import path_template -from google.api_core import gapic_v1 - -from google.protobuf import json_format -from requests import __version__ as requests_version -import dataclasses -import re -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union -import warnings - -try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] -except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object] # type: ignore - - -from google.ai.generativelanguage_v1beta2.types import discuss_service - -from .base import DiscussServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO - - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, - grpc_version=None, - rest_version=requests_version, -) - - -class DiscussServiceRestInterceptor: - """Interceptor for DiscussService. - - Interceptors are used to manipulate requests, request metadata, and responses - in arbitrary ways. - Example use cases include: - * Logging - * Verifying requests according to service or custom semantics - * Stripping extraneous information from responses - - These use cases and more can be enabled by injecting an - instance of a custom subclass when constructing the DiscussServiceRestTransport. - - .. code-block:: python - class MyCustomDiscussServiceInterceptor(DiscussServiceRestInterceptor): - def pre_count_message_tokens(self, request, metadata): - logging.log(f"Received request: {request}") - return request, metadata - - def post_count_message_tokens(self, response): - logging.log(f"Received response: {response}") - return response - - def pre_generate_message(self, request, metadata): - logging.log(f"Received request: {request}") - return request, metadata - - def post_generate_message(self, response): - logging.log(f"Received response: {response}") - return response - - transport = DiscussServiceRestTransport(interceptor=MyCustomDiscussServiceInterceptor()) - client = DiscussServiceClient(transport=transport) - - - """ - def pre_count_message_tokens(self, request: discuss_service.CountMessageTokensRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[discuss_service.CountMessageTokensRequest, Sequence[Tuple[str, str]]]: - """Pre-rpc interceptor for count_message_tokens - - Override in a subclass to manipulate the request or metadata - before they are sent to the DiscussService server. - """ - return request, metadata - - def post_count_message_tokens(self, response: discuss_service.CountMessageTokensResponse) -> discuss_service.CountMessageTokensResponse: - """Post-rpc interceptor for count_message_tokens - - Override in a subclass to manipulate the response - after it is returned by the DiscussService server but before - it is returned to user code. 
- """ - return response - def pre_generate_message(self, request: discuss_service.GenerateMessageRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[discuss_service.GenerateMessageRequest, Sequence[Tuple[str, str]]]: - """Pre-rpc interceptor for generate_message - - Override in a subclass to manipulate the request or metadata - before they are sent to the DiscussService server. - """ - return request, metadata - - def post_generate_message(self, response: discuss_service.GenerateMessageResponse) -> discuss_service.GenerateMessageResponse: - """Post-rpc interceptor for generate_message - - Override in a subclass to manipulate the response - after it is returned by the DiscussService server but before - it is returned to user code. - """ - return response - - -@dataclasses.dataclass -class DiscussServiceRestStub: - _session: AuthorizedSession - _host: str - _interceptor: DiscussServiceRestInterceptor - - -class DiscussServiceRestTransport(DiscussServiceTransport): - """REST backend transport for DiscussService. - - An API for using Generative Language Models (GLMs) in dialog - applications. - Also known as large language models (LLMs), this API provides - models that are trained for multi-turn dialog. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends JSON representations of protocol buffers over HTTP/1.1 - - """ - - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - client_cert_source_for_mtls: Optional[Callable[[ - ], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - url_scheme: str = 'https', - interceptor: Optional[DiscussServiceRestInterceptor] = None, - api_audience: Optional[str] = None, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client - certificate to configure mutual TLS HTTP channel. It is ignored - if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you are developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - url_scheme: the protocol scheme for the API endpoint. Normally - "https", but for testing or local servers, - "http" can be specified. 
- """ - # Run the base constructor - # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. - # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the - # credentials object - maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) - if maybe_url_match is None: - raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER - - url_match_items = maybe_url_match.groupdict() - - host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host - - super().__init__( - host=host, - credentials=credentials, - client_info=client_info, - always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience - ) - self._session = AuthorizedSession( - self._credentials, default_host=self.DEFAULT_HOST) - if client_cert_source_for_mtls: - self._session.configure_mtls_channel(client_cert_source_for_mtls) - self._interceptor = interceptor or DiscussServiceRestInterceptor() - self._prep_wrapped_messages(client_info) - - class _CountMessageTokens(DiscussServiceRestStub): - def __hash__(self): - return hash("CountMessageTokens") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: discuss_service.CountMessageTokensRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> discuss_service.CountMessageTokensResponse: - r"""Call the count message tokens method over HTTP. - - Args: - request (~.discuss_service.CountMessageTokensRequest): - The request object. Counts the number of tokens in the ``prompt`` sent to a - model. - - Models may tokenize text differently, so each model may - return a different ``token_count``. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.discuss_service.CountMessageTokensResponse: - A response from ``CountMessageTokens``. - - It returns the model's ``token_count`` for the - ``prompt``. 
- - """ - - http_options: List[Dict[str, str]] = [{ - 'method': 'post', - 'uri': '/v1beta2/{model=models/*}:countMessageTokens', - 'body': '*', - }, - ] - request, metadata = self._interceptor.pre_count_message_tokens(request, metadata) - pb_request = discuss_service.CountMessageTokensRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) - - # Jsonify the request body - - body = json_format.MessageToJson( - transcoded_request['body'], - including_default_value_fields=False, - use_integers_for_enums=True - ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] - - # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) - query_params.update(self._get_unset_required_fields(query_params)) - - query_params["$alt"] = "json;enum-encoding=int" - - # Send the request - headers = dict(metadata) - headers['Content-Type'] = 'application/json' - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, - ) - - # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception - # subclass. - if response.status_code >= 400: - raise core_exceptions.from_http_response(response) - - # Return the response - resp = discuss_service.CountMessageTokensResponse() - pb_resp = discuss_service.CountMessageTokensResponse.pb(resp) - - json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) - resp = self._interceptor.post_count_message_tokens(resp) - return resp - - class _GenerateMessage(DiscussServiceRestStub): - def __hash__(self): - return hash("GenerateMessage") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: discuss_service.GenerateMessageRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> discuss_service.GenerateMessageResponse: - r"""Call the generate message method over HTTP. - - Args: - request (~.discuss_service.GenerateMessageRequest): - The request object. Request to generate a message - response from the model. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.discuss_service.GenerateMessageResponse: - The response from the model. - - This includes candidate messages and - conversation history in the form of - chronologically-ordered messages. 
- - """ - - http_options: List[Dict[str, str]] = [{ - 'method': 'post', - 'uri': '/v1beta2/{model=models/*}:generateMessage', - 'body': '*', - }, - ] - request, metadata = self._interceptor.pre_generate_message(request, metadata) - pb_request = discuss_service.GenerateMessageRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) - - # Jsonify the request body - - body = json_format.MessageToJson( - transcoded_request['body'], - including_default_value_fields=False, - use_integers_for_enums=True - ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] - - # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) - query_params.update(self._get_unset_required_fields(query_params)) - - query_params["$alt"] = "json;enum-encoding=int" - - # Send the request - headers = dict(metadata) - headers['Content-Type'] = 'application/json' - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, - ) - - # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception - # subclass. - if response.status_code >= 400: - raise core_exceptions.from_http_response(response) - - # Return the response - resp = discuss_service.GenerateMessageResponse() - pb_resp = discuss_service.GenerateMessageResponse.pb(resp) - - json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) - resp = self._interceptor.post_generate_message(resp) - return resp - - @property - def count_message_tokens(self) -> Callable[ - [discuss_service.CountMessageTokensRequest], - discuss_service.CountMessageTokensResponse]: - # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. - # In C++ this would require a dynamic_cast - return self._CountMessageTokens(self._session, self._host, self._interceptor) # type: ignore - - @property - def generate_message(self) -> Callable[ - [discuss_service.GenerateMessageRequest], - discuss_service.GenerateMessageResponse]: - # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. - # In C++ this would require a dynamic_cast - return self._GenerateMessage(self._session, self._host, self._interceptor) # type: ignore - - @property - def kind(self) -> str: - return "rest" - - def close(self): - self._session.close() - - -__all__=( - 'DiscussServiceRestTransport', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/__init__.py deleted file mode 100644 index 2c368b92d844..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from .client import ModelServiceClient -from .async_client import ModelServiceAsyncClient - -__all__ = ( - 'ModelServiceClient', - 'ModelServiceAsyncClient', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/async_client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/async_client.py deleted file mode 100644 index 4710e8d992c2..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/async_client.py +++ /dev/null @@ -1,431 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from collections import OrderedDict -import functools -import re -from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union - -from google.ai.generativelanguage_v1beta2 import gapic_version as package_version - -from google.api_core.client_options import ClientOptions -from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] -except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object] # type: ignore - -from google.ai.generativelanguage_v1beta2.services.model_service import pagers -from google.ai.generativelanguage_v1beta2.types import model -from google.ai.generativelanguage_v1beta2.types import model_service -from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport -from .client import ModelServiceClient - - -class ModelServiceAsyncClient: - """Provides methods for getting metadata information about - Generative Models. 
- """ - - _client: ModelServiceClient - - DEFAULT_ENDPOINT = ModelServiceClient.DEFAULT_ENDPOINT - DEFAULT_MTLS_ENDPOINT = ModelServiceClient.DEFAULT_MTLS_ENDPOINT - - model_path = staticmethod(ModelServiceClient.model_path) - parse_model_path = staticmethod(ModelServiceClient.parse_model_path) - common_billing_account_path = staticmethod(ModelServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(ModelServiceClient.parse_common_billing_account_path) - common_folder_path = staticmethod(ModelServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(ModelServiceClient.parse_common_folder_path) - common_organization_path = staticmethod(ModelServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(ModelServiceClient.parse_common_organization_path) - common_project_path = staticmethod(ModelServiceClient.common_project_path) - parse_common_project_path = staticmethod(ModelServiceClient.parse_common_project_path) - common_location_path = staticmethod(ModelServiceClient.common_location_path) - parse_common_location_path = staticmethod(ModelServiceClient.parse_common_location_path) - - @classmethod - def from_service_account_info(cls, info: dict, *args, **kwargs): - """Creates an instance of this client using the provided credentials - info. - - Args: - info (dict): The service account private key info. - args: Additional arguments to pass to the constructor. - kwargs: Additional arguments to pass to the constructor. - - Returns: - ModelServiceAsyncClient: The constructed client. - """ - return ModelServiceClient.from_service_account_info.__func__(ModelServiceAsyncClient, info, *args, **kwargs) # type: ignore - - @classmethod - def from_service_account_file(cls, filename: str, *args, **kwargs): - """Creates an instance of this client using the provided credentials - file. - - Args: - filename (str): The path to the service account private key json - file. - args: Additional arguments to pass to the constructor. - kwargs: Additional arguments to pass to the constructor. - - Returns: - ModelServiceAsyncClient: The constructed client. - """ - return ModelServiceClient.from_service_account_file.__func__(ModelServiceAsyncClient, filename, *args, **kwargs) # type: ignore - - from_service_account_json = from_service_account_file - - @classmethod - def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None): - """Return the API endpoint and client cert source for mutual TLS. - - The client cert source is determined in the following order: - (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the - client cert source is None. - (2) if `client_options.client_cert_source` is provided, use the provided one; if the - default client cert source exists, use the default one; otherwise the client cert - source is None. - - The API endpoint is determined in the following order: - (1) if `client_options.api_endpoint` if provided, use the provided one. - (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the - default mTLS endpoint; if the environment variable is "never", use the default API - endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise - use the default API endpoint. - - More details can be found at https://google.aip.dev/auth/4114. - - Args: - client_options (google.api_core.client_options.ClientOptions): Custom options for the - client. 
Only the `api_endpoint` and `client_cert_source` properties may be used - in this method. - - Returns: - Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the - client cert source to use. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If any errors happen. - """ - return ModelServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore - - @property - def transport(self) -> ModelServiceTransport: - """Returns the transport used by the client instance. - - Returns: - ModelServiceTransport: The transport used by the client instance. - """ - return self._client.transport - - get_transport_class = functools.partial(type(ModelServiceClient).get_transport_class, type(ModelServiceClient)) - - def __init__(self, *, - credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, ModelServiceTransport] = "grpc_asyncio", - client_options: Optional[ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiates the model service client. - - Args: - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - transport (Union[str, ~.ModelServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - """ - self._client = ModelServiceClient( - credentials=credentials, - transport=transport, - client_options=client_options, - client_info=client_info, - - ) - - async def get_model(self, - request: Optional[Union[model_service.GetModelRequest, dict]] = None, - *, - name: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: - r"""Gets information about a specific Model. - - .. code-block:: python - - # This snippet has been automatically generated and should be regarded as a - # code template only. - # It will require modifications to work: - # - It may require correct/in-range values for request initialization. 
- # - It may require specifying regional endpoints when creating the service - # client as shown in: - # https://googleapis.dev/python/google-api-core/latest/client_options.html - from google.ai import generativelanguage_v1beta2 - - async def sample_get_model(): - # Create a client - client = generativelanguage_v1beta2.ModelServiceAsyncClient() - - # Initialize request argument(s) - request = generativelanguage_v1beta2.GetModelRequest( - name="name_value", - ) - - # Make the request - response = await client.get_model(request=request) - - # Handle the response - print(response) - - Args: - request (Optional[Union[google.ai.generativelanguage_v1beta2.types.GetModelRequest, dict]]): - The request object. Request for getting information about - a specific Model. - name (:class:`str`): - Required. The resource name of the model. - - This name should match a model name returned by the - ``ListModels`` method. - - Format: ``models/{model}`` - - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - google.ai.generativelanguage_v1beta2.types.Model: - Information about a Generative - Language Model. - - """ - # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") - - request = model_service.GetModelRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.get_model, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), - ) - - # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # Done; return the response. - return response - - async def list_models(self, - request: Optional[Union[model_service.ListModelsRequest, dict]] = None, - *, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsAsyncPager: - r"""Lists models available through the API. - - .. code-block:: python - - # This snippet has been automatically generated and should be regarded as a - # code template only. - # It will require modifications to work: - # - It may require correct/in-range values for request initialization. 
- # - It may require specifying regional endpoints when creating the service - # client as shown in: - # https://googleapis.dev/python/google-api-core/latest/client_options.html - from google.ai import generativelanguage_v1beta2 - - async def sample_list_models(): - # Create a client - client = generativelanguage_v1beta2.ModelServiceAsyncClient() - - # Initialize request argument(s) - request = generativelanguage_v1beta2.ListModelsRequest( - ) - - # Make the request - page_result = client.list_models(request=request) - - # Handle the response - async for response in page_result: - print(response) - - Args: - request (Optional[Union[google.ai.generativelanguage_v1beta2.types.ListModelsRequest, dict]]): - The request object. Request for listing all Models. - page_size (:class:`int`): - The maximum number of ``Models`` to return (per page). - - The service may return fewer models. If unspecified, at - most 50 models will be returned per page. This method - returns at most 1000 models per page, even if you pass a - larger page_size. - - This corresponds to the ``page_size`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - page_token (:class:`str`): - A page token, received from a previous ``ListModels`` - call. - - Provide the ``page_token`` returned by one request as an - argument to the next request to retrieve the next page. - - When paginating, all other parameters provided to - ``ListModels`` must match the call that provided the - page token. - - This corresponds to the ``page_token`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - google.ai.generativelanguage_v1beta2.services.model_service.pagers.ListModelsAsyncPager: - Response from ListModel containing a paginated list of - Models. - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - has_flattened_params = any([page_size, page_token]) - if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") - - request = model_service.ListModelsRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - if page_size is not None: - request.page_size = page_size - if page_token is not None: - request.page_token = page_token - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.list_models, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # This method is paged; wrap the response in a pager, which provides - # an `__aiter__` convenience method. - response = pagers.ListModelsAsyncPager( - method=rpc, - request=request, - response=response, - metadata=metadata, - ) - - # Done; return the response. 
- return response - - async def __aenter__(self) -> "ModelServiceAsyncClient": - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.transport.close() - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - - -__all__ = ( - "ModelServiceAsyncClient", -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/client.py deleted file mode 100644 index 9bcf43c759e5..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/client.py +++ /dev/null @@ -1,635 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from collections import OrderedDict -import os -import re -from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast - -from google.ai.generativelanguage_v1beta2 import gapic_version as package_version - -from google.api_core import client_options as client_options_lib -from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore - -try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] -except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object] # type: ignore - -from google.ai.generativelanguage_v1beta2.services.model_service import pagers -from google.ai.generativelanguage_v1beta2.types import model -from google.ai.generativelanguage_v1beta2.types import model_service -from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc import ModelServiceGrpcTransport -from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport -from .transports.rest import ModelServiceRestTransport - - -class ModelServiceClientMeta(type): - """Metaclass for the ModelService client. - - This provides class-level methods for building and retrieving - support objects (e.g. transport) without polluting the client instance - objects. - """ - _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] - _transport_registry["grpc"] = ModelServiceGrpcTransport - _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport - _transport_registry["rest"] = ModelServiceRestTransport - - def get_transport_class(cls, - label: Optional[str] = None, - ) -> Type[ModelServiceTransport]: - """Returns an appropriate transport class. 
- - Args: - label: The name of the desired transport. If none is - provided, then the first transport in the registry is used. - - Returns: - The transport class to use. - """ - # If a specific transport is requested, return that one. - if label: - return cls._transport_registry[label] - - # No transport is requested; return the default (that is, the first one - # in the dictionary). - return next(iter(cls._transport_registry.values())) - - -class ModelServiceClient(metaclass=ModelServiceClientMeta): - """Provides methods for getting metadata information about - Generative Models. - """ - - @staticmethod - def _get_default_mtls_endpoint(api_endpoint): - """Converts api endpoint to mTLS endpoint. - - Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to - "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. - Args: - api_endpoint (Optional[str]): the api endpoint to convert. - Returns: - str: converted mTLS api endpoint. - """ - if not api_endpoint: - return api_endpoint - - mtls_endpoint_re = re.compile( - r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?" - ) - - m = mtls_endpoint_re.match(api_endpoint) - name, mtls, sandbox, googledomain = m.groups() - if mtls or not googledomain: - return api_endpoint - - if sandbox: - return api_endpoint.replace( - "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" - ) - - return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - - DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" - DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore - DEFAULT_ENDPOINT - ) - - @classmethod - def from_service_account_info(cls, info: dict, *args, **kwargs): - """Creates an instance of this client using the provided credentials - info. - - Args: - info (dict): The service account private key info. - args: Additional arguments to pass to the constructor. - kwargs: Additional arguments to pass to the constructor. - - Returns: - ModelServiceClient: The constructed client. - """ - credentials = service_account.Credentials.from_service_account_info(info) - kwargs["credentials"] = credentials - return cls(*args, **kwargs) - - @classmethod - def from_service_account_file(cls, filename: str, *args, **kwargs): - """Creates an instance of this client using the provided credentials - file. - - Args: - filename (str): The path to the service account private key json - file. - args: Additional arguments to pass to the constructor. - kwargs: Additional arguments to pass to the constructor. - - Returns: - ModelServiceClient: The constructed client. - """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs["credentials"] = credentials - return cls(*args, **kwargs) - - from_service_account_json = from_service_account_file - - @property - def transport(self) -> ModelServiceTransport: - """Returns the transport used by the client instance. - - Returns: - ModelServiceTransport: The transport used by the client - instance.
- """ - return self._transport - - @staticmethod - def model_path(model: str,) -> str: - """Returns a fully-qualified model string.""" - return "models/{model}".format(model=model, ) - - @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: - """Parses a model path into its component segments.""" - m = re.match(r"^models/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: - """Returns a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) - - @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: - """Parse a billing_account path into its component segments.""" - m = re.match(r"^billingAccounts/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_folder_path(folder: str, ) -> str: - """Returns a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) - - @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: - """Parse a folder path into its component segments.""" - m = re.match(r"^folders/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_organization_path(organization: str, ) -> str: - """Returns a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) - - @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: - """Parse a organization path into its component segments.""" - m = re.match(r"^organizations/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_project_path(project: str, ) -> str: - """Returns a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) - - @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: - """Parse a project path into its component segments.""" - m = re.match(r"^projects/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_location_path(project: str, location: str, ) -> str: - """Returns a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) - - @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: - """Parse a location path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) - return m.groupdict() if m else {} - - @classmethod - def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None): - """Return the API endpoint and client cert source for mutual TLS. - - The client cert source is determined in the following order: - (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the - client cert source is None. - (2) if `client_options.client_cert_source` is provided, use the provided one; if the - default client cert source exists, use the default one; otherwise the client cert - source is None. - - The API endpoint is determined in the following order: - (1) if `client_options.api_endpoint` if provided, use the provided one. - (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the - default mTLS endpoint; if the environment variable is "never", use the default API - endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise - use the default API endpoint. 
- - More details can be found at https://google.aip.dev/auth/4114. - - Args: - client_options (google.api_core.client_options.ClientOptions): Custom options for the - client. Only the `api_endpoint` and `client_cert_source` properties may be used - in this method. - - Returns: - Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the - client cert source to use. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If any errors happen. - """ - if client_options is None: - client_options = client_options_lib.ClientOptions() - use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") - use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_client_cert not in ("true", "false"): - raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`") - if use_mtls_endpoint not in ("auto", "never", "always"): - raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`") - - # Figure out the client cert source to use. - client_cert_source = None - if use_client_cert == "true": - if client_options.client_cert_source: - client_cert_source = client_options.client_cert_source - elif mtls.has_default_client_cert_source(): - client_cert_source = mtls.default_client_cert_source() - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source): - api_endpoint = cls.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = cls.DEFAULT_ENDPOINT - - return api_endpoint, client_cert_source - - def __init__(self, *, - credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, ModelServiceTransport]] = None, - client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiates the model service client. - - Args: - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - transport (Union[str, ModelServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the - client. It won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. 
- client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - """ - if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - client_options = cast(client_options_lib.ClientOptions, client_options) - - api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) - - api_key_value = getattr(client_options, "api_key", None) - if api_key_value and credentials: - raise ValueError("client_options.api_key and credentials are mutually exclusive") - - # Save or instantiate the transport. - # Ordinarily, we provide the transport, but allowing a custom transport - # instance provides an extensibility point for unusual situations. - if isinstance(transport, ModelServiceTransport): - # transport is a ModelServiceTransport instance. - if credentials or client_options.credentials_file or api_key_value: - raise ValueError("When providing a transport instance, " - "provide its credentials directly.") - if client_options.scopes: - raise ValueError( - "When providing a transport instance, provide its scopes " - "directly." - ) - self._transport = transport - else: - import google.auth._default # type: ignore - - if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): - credentials = google.auth._default.get_api_key_credentials(api_key_value) - - Transport = type(self).get_transport_class(transport) - self._transport = Transport( - credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - client_cert_source_for_mtls=client_cert_source_func, - quota_project_id=client_options.quota_project_id, - client_info=client_info, - always_use_jwt_access=True, - api_audience=client_options.api_audience, - ) - - def get_model(self, - request: Optional[Union[model_service.GetModelRequest, dict]] = None, - *, - name: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: - r"""Gets information about a specific Model. - - .. code-block:: python - - # This snippet has been automatically generated and should be regarded as a - # code template only. - # It will require modifications to work: - # - It may require correct/in-range values for request initialization. - # - It may require specifying regional endpoints when creating the service - # client as shown in: - # https://googleapis.dev/python/google-api-core/latest/client_options.html - from google.ai import generativelanguage_v1beta2 - - def sample_get_model(): - # Create a client - client = generativelanguage_v1beta2.ModelServiceClient() - - # Initialize request argument(s) - request = generativelanguage_v1beta2.GetModelRequest( - name="name_value", - ) - - # Make the request - response = client.get_model(request=request) - - # Handle the response - print(response) - - Args: - request (Union[google.ai.generativelanguage_v1beta2.types.GetModelRequest, dict]): - The request object. Request for getting information about - a specific Model. 
- name (str): - Required. The resource name of the model. - - This name should match a model name returned by the - ``ListModels`` method. - - Format: ``models/{model}`` - - This corresponds to the ``name`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - google.ai.generativelanguage_v1beta2.types.Model: - Information about a Generative - Language Model. - - """ - # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - has_flattened_params = any([name]) - if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') - - # Minor optimization to avoid making a copy if the user passes - # in a model_service.GetModelRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, model_service.GetModelRequest): - request = model_service.GetModelRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. - if name is not None: - request.name = name - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.get_model] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), - ) - - # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # Done; return the response. - return response - - def list_models(self, - request: Optional[Union[model_service.ListModelsRequest, dict]] = None, - *, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsPager: - r"""Lists models available through the API. - - .. code-block:: python - - # This snippet has been automatically generated and should be regarded as a - # code template only. - # It will require modifications to work: - # - It may require correct/in-range values for request initialization. - # - It may require specifying regional endpoints when creating the service - # client as shown in: - # https://googleapis.dev/python/google-api-core/latest/client_options.html - from google.ai import generativelanguage_v1beta2 - - def sample_list_models(): - # Create a client - client = generativelanguage_v1beta2.ModelServiceClient() - - # Initialize request argument(s) - request = generativelanguage_v1beta2.ListModelsRequest( - ) - - # Make the request - page_result = client.list_models(request=request) - - # Handle the response - for response in page_result: - print(response) - - Args: - request (Union[google.ai.generativelanguage_v1beta2.types.ListModelsRequest, dict]): - The request object. Request for listing all Models. - page_size (int): - The maximum number of ``Models`` to return (per page). - - The service may return fewer models. 
If unspecified, at - most 50 models will be returned per page. This method - returns at most 1000 models per page, even if you pass a - larger page_size. - - This corresponds to the ``page_size`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - page_token (str): - A page token, received from a previous ``ListModels`` - call. - - Provide the ``page_token`` returned by one request as an - argument to the next request to retrieve the next page. - - When paginating, all other parameters provided to - ``ListModels`` must match the call that provided the - page token. - - This corresponds to the ``page_token`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - google.ai.generativelanguage_v1beta2.services.model_service.pagers.ListModelsPager: - Response from ListModel containing a paginated list of - Models. - - Iterating over this object will yield results and - resolve additional pages automatically. - - """ - # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - has_flattened_params = any([page_size, page_token]) - if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') - - # Minor optimization to avoid making a copy if the user passes - # in a model_service.ListModelsRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, model_service.ListModelsRequest): - request = model_service.ListModelsRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. - if page_size is not None: - request.page_size = page_size - if page_token is not None: - request.page_token = page_token - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.list_models] - - # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # This method is paged; wrap the response in a pager, which provides - # an `__iter__` convenience method. - response = pagers.ListModelsPager( - method=rpc, - request=request, - response=response, - metadata=metadata, - ) - - # Done; return the response. - return response - - def __enter__(self) -> "ModelServiceClient": - return self - - def __exit__(self, type, value, traceback): - """Releases underlying transport's resources. - - .. warning:: - ONLY use as a context manager if the transport is NOT shared - with other clients! Exiting the with block will CLOSE the transport - and may cause errors in other clients! 
- """ - self.transport.close() - - - - - - - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - - -__all__ = ( - "ModelServiceClient", -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/pagers.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/pagers.py deleted file mode 100644 index 2183050a4126..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/pagers.py +++ /dev/null @@ -1,140 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from typing import Any, AsyncIterator, Awaitable, Callable, Sequence, Tuple, Optional, Iterator - -from google.ai.generativelanguage_v1beta2.types import model -from google.ai.generativelanguage_v1beta2.types import model_service - - -class ListModelsPager: - """A pager for iterating through ``list_models`` requests. - - This class thinly wraps an initial - :class:`google.ai.generativelanguage_v1beta2.types.ListModelsResponse` object, and - provides an ``__iter__`` method to iterate through its - ``models`` field. - - If there are more pages, the ``__iter__`` method will make additional - ``ListModels`` requests and continue to iterate - through the ``models`` field on the - corresponding responses. - - All the usual :class:`google.ai.generativelanguage_v1beta2.types.ListModelsResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. - """ - def __init__(self, - method: Callable[..., model_service.ListModelsResponse], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): - """Instantiate the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (google.ai.generativelanguage_v1beta2.types.ListModelsRequest): - The initial request object. - response (google.ai.generativelanguage_v1beta2.types.ListModelsResponse): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. 
- """ - self._method = method - self._request = model_service.ListModelsRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - def pages(self) -> Iterator[model_service.ListModelsResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = self._method(self._request, metadata=self._metadata) - yield self._response - - def __iter__(self) -> Iterator[model.Model]: - for page in self.pages: - yield from page.models - - def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) - - -class ListModelsAsyncPager: - """A pager for iterating through ``list_models`` requests. - - This class thinly wraps an initial - :class:`google.ai.generativelanguage_v1beta2.types.ListModelsResponse` object, and - provides an ``__aiter__`` method to iterate through its - ``models`` field. - - If there are more pages, the ``__aiter__`` method will make additional - ``ListModels`` requests and continue to iterate - through the ``models`` field on the - corresponding responses. - - All the usual :class:`google.ai.generativelanguage_v1beta2.types.ListModelsResponse` - attributes are available on the pager. If multiple requests are made, only - the most recent response is retained, and thus used for attribute lookup. - """ - def __init__(self, - method: Callable[..., Awaitable[model_service.ListModelsResponse]], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): - """Instantiates the pager. - - Args: - method (Callable): The method that was originally called, and - which instantiated this pager. - request (google.ai.generativelanguage_v1beta2.types.ListModelsRequest): - The initial request object. - response (google.ai.generativelanguage_v1beta2.types.ListModelsResponse): - The initial response object. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. 
- """ - self._method = method - self._request = model_service.ListModelsRequest(request) - self._response = response - self._metadata = metadata - - def __getattr__(self, name: str) -> Any: - return getattr(self._response, name) - - @property - async def pages(self) -> AsyncIterator[model_service.ListModelsResponse]: - yield self._response - while self._response.next_page_token: - self._request.page_token = self._response.next_page_token - self._response = await self._method(self._request, metadata=self._metadata) - yield self._response - def __aiter__(self) -> AsyncIterator[model.Model]: - async def async_generator(): - async for page in self.pages: - for response in page.models: - yield response - - return async_generator() - - def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/base.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/base.py deleted file mode 100644 index 3f41738067e8..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/base.py +++ /dev/null @@ -1,162 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import abc -from typing import Awaitable, Callable, Dict, Optional, Sequence, Union - -from google.ai.generativelanguage_v1beta2 import gapic_version as package_version - -import google.auth # type: ignore -import google.api_core -from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.ai.generativelanguage_v1beta2.types import model -from google.ai.generativelanguage_v1beta2.types import model_service - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - - -class ModelServiceTransport(abc.ABC): - """Abstract transport class for ModelService.""" - - AUTH_SCOPES = ( - ) - - DEFAULT_HOST: str = 'generativelanguage.googleapis.com' - def __init__( - self, *, - host: str = DEFAULT_HOST, - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - **kwargs, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. 
These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): A list of scopes. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - """ - - scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} - - # Save the scopes. - self._scopes = scopes - - # If no credentials are provided, then determine the appropriate - # defaults. - if credentials and credentials_file: - raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") - - if credentials_file is not None: - credentials, _ = google.auth.load_credentials_from_file( - credentials_file, - **scopes_kwargs, - quota_project_id=quota_project_id - ) - elif credentials is None: - credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id) - # Don't apply audience if the credentials file passed from user. - if hasattr(credentials, "with_gdch_audience"): - credentials = credentials.with_gdch_audience(api_audience if api_audience else host) - - # If the credentials are service account credentials, then always try to use self signed JWT. - if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"): - credentials = credentials.with_always_use_jwt_access(True) - - # Save the credentials. - self._credentials = credentials - - # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' - self._host = host - - def _prep_wrapped_messages(self, client_info): - # Precompute the wrapped methods. - self._wrapped_methods = { - self.get_model: gapic_v1.method.wrap_method( - self.get_model, - default_timeout=None, - client_info=client_info, - ), - self.list_models: gapic_v1.method.wrap_method( - self.list_models, - default_timeout=None, - client_info=client_info, - ), - } - - def close(self): - """Closes resources associated with the transport. - - .. warning:: - Only call this method if the transport is NOT shared - with other clients - this may cause errors in other clients! 
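
A transport can also be constructed explicitly and handed to the client, in which case the caller owns its lifecycle; a minimal sketch (assuming application default credentials), using the ``create_channel`` helper and ``channel`` argument shown later in this diff:

.. code-block:: python

    from google.ai import generativelanguage_v1beta2
    from google.ai.generativelanguage_v1beta2.services.model_service.transports import (
        ModelServiceGrpcTransport,
    )

    # Build a dedicated channel and transport; per the warning above, only
    # close() the transport when no other client is still using it.
    channel = ModelServiceGrpcTransport.create_channel("generativelanguage.googleapis.com")
    transport = ModelServiceGrpcTransport(channel=channel)
    client = generativelanguage_v1beta2.ModelServiceClient(transport=transport)
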
- """ - raise NotImplementedError() - - @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - Union[ - model.Model, - Awaitable[model.Model] - ]]: - raise NotImplementedError() - - @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - Union[ - model_service.ListModelsResponse, - Awaitable[model_service.ListModelsResponse] - ]]: - raise NotImplementedError() - - @property - def kind(self) -> str: - raise NotImplementedError() - - -__all__ = ( - 'ModelServiceTransport', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc.py deleted file mode 100644 index 892193957a68..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc.py +++ /dev/null @@ -1,292 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import warnings -from typing import Callable, Dict, Optional, Sequence, Tuple, Union - -from google.api_core import grpc_helpers -from google.api_core import gapic_v1 -import google.auth # type: ignore -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore - -from google.ai.generativelanguage_v1beta2.types import model -from google.ai.generativelanguage_v1beta2.types import model_service -from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO - - -class ModelServiceGrpcTransport(ModelServiceTransport): - """gRPC backend transport for ModelService. - - Provides methods for getting metadata information about - Generative Models. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. - """ - _stubs: Dict[str, Callable] - - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: - """Instantiate the transport. 
- - Args: - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. - client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): - A callback to provide client certificate bytes and private key bytes, - both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - self._grpc_channel = None - self._ssl_channel_credentials = ssl_channel_credentials - self._stubs: Dict[str, Callable] = {} - - if api_mtls_endpoint: - warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) - if client_cert_source: - warnings.warn("client_cert_source is deprecated", DeprecationWarning) - - if channel: - # Ignore credentials if a channel was passed. - credentials = False - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - self._ssl_channel_credentials = None - - else: - if api_mtls_endpoint: - host = api_mtls_endpoint - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. 
- if client_cert_source: - cert, key = client_cert_source() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - self._ssl_channel_credentials = SslCredentials().ssl_credentials - - else: - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - - # The base transport sets the host, credentials and scopes - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - client_info=client_info, - always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience, - ) - - if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( - self._host, - # use the credentials which are saved - credentials=self._credentials, - # Set ``credentials_file`` to ``None`` here as - # the credentials that we saved earlier should be used. - credentials_file=None, - scopes=self._scopes, - ssl_credentials=self._ssl_channel_credentials, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - # Wrap messages. This must be done after self._grpc_channel exists - self._prep_wrapped_messages(client_info) - - @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: - """Create and return a gRPC channel object. - Args: - host (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - grpc.Channel: A gRPC channel object. - - Raises: - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - - return grpc_helpers.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - quota_project_id=quota_project_id, - default_scopes=cls.AUTH_SCOPES, - scopes=scopes, - default_host=cls.DEFAULT_HOST, - **kwargs - ) - - @property - def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ - return self._grpc_channel - - @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - model.Model]: - r"""Return a callable for the get model method over gRPC. - - Gets information about a specific Model. 
- - Returns: - Callable[[~.GetModelRequest], - ~.Model]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if 'get_model' not in self._stubs: - self._stubs['get_model'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta2.ModelService/GetModel', - request_serializer=model_service.GetModelRequest.serialize, - response_deserializer=model.Model.deserialize, - ) - return self._stubs['get_model'] - - @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - model_service.ListModelsResponse]: - r"""Return a callable for the list models method over gRPC. - - Lists models available through the API. - - Returns: - Callable[[~.ListModelsRequest], - ~.ListModelsResponse]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if 'list_models' not in self._stubs: - self._stubs['list_models'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta2.ModelService/ListModels', - request_serializer=model_service.ListModelsRequest.serialize, - response_deserializer=model_service.ListModelsResponse.deserialize, - ) - return self._stubs['list_models'] - - def close(self): - self.grpc_channel.close() - - @property - def kind(self) -> str: - return "grpc" - - -__all__ = ( - 'ModelServiceGrpcTransport', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc_asyncio.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc_asyncio.py deleted file mode 100644 index 49b3a42dee4c..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/grpc_asyncio.py +++ /dev/null @@ -1,291 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import warnings -from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union - -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers_async -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore -from grpc.experimental import aio # type: ignore - -from google.ai.generativelanguage_v1beta2.types import model -from google.ai.generativelanguage_v1beta2.types import model_service -from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO -from .grpc import ModelServiceGrpcTransport - - -class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport): - """gRPC AsyncIO backend transport for ModelService. - - Provides methods for getting metadata information about - Generative Models. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. - """ - - _grpc_channel: aio.Channel - _stubs: Dict[str, Callable] = {} - - @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: - """Create and return a gRPC AsyncIO channel object. - Args: - host (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - aio.Channel: A gRPC AsyncIO channel object. - """ - - return grpc_helpers_async.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - quota_project_id=quota_project_id, - default_scopes=cls.AUTH_SCOPES, - scopes=scopes, - default_host=cls.DEFAULT_HOST, - **kwargs - ) - - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: - """Instantiate the transport. 
- - Args: - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. - client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): - A callback to provide client certificate bytes and private key bytes, - both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - self._grpc_channel = None - self._ssl_channel_credentials = ssl_channel_credentials - self._stubs: Dict[str, Callable] = {} - - if api_mtls_endpoint: - warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) - if client_cert_source: - warnings.warn("client_cert_source is deprecated", DeprecationWarning) - - if channel: - # Ignore credentials if a channel was passed. - credentials = False - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - self._ssl_channel_credentials = None - else: - if api_mtls_endpoint: - host = api_mtls_endpoint - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. 
- if client_cert_source: - cert, key = client_cert_source() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - self._ssl_channel_credentials = SslCredentials().ssl_credentials - - else: - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - - # The base transport sets the host, credentials and scopes - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - client_info=client_info, - always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience, - ) - - if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( - self._host, - # use the credentials which are saved - credentials=self._credentials, - # Set ``credentials_file`` to ``None`` here as - # the credentials that we saved earlier should be used. - credentials_file=None, - scopes=self._scopes, - ssl_credentials=self._ssl_channel_credentials, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - # Wrap messages. This must be done after self._grpc_channel exists - self._prep_wrapped_messages(client_info) - - @property - def grpc_channel(self) -> aio.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. - """ - # Return the channel from cache. - return self._grpc_channel - - @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - Awaitable[model.Model]]: - r"""Return a callable for the get model method over gRPC. - - Gets information about a specific Model. - - Returns: - Callable[[~.GetModelRequest], - Awaitable[~.Model]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if 'get_model' not in self._stubs: - self._stubs['get_model'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta2.ModelService/GetModel', - request_serializer=model_service.GetModelRequest.serialize, - response_deserializer=model.Model.deserialize, - ) - return self._stubs['get_model'] - - @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - Awaitable[model_service.ListModelsResponse]]: - r"""Return a callable for the list models method over gRPC. - - Lists models available through the API. - - Returns: - Callable[[~.ListModelsRequest], - Awaitable[~.ListModelsResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. 
- if 'list_models' not in self._stubs: - self._stubs['list_models'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta2.ModelService/ListModels', - request_serializer=model_service.ListModelsRequest.serialize, - response_deserializer=model_service.ListModelsResponse.deserialize, - ) - return self._stubs['list_models'] - - def close(self): - return self.grpc_channel.close() - - -__all__ = ( - 'ModelServiceGrpcAsyncIOTransport', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/rest.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/rest.py deleted file mode 100644 index db28ab8b2b81..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/rest.py +++ /dev/null @@ -1,397 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from google.auth.transport.requests import AuthorizedSession # type: ignore -import json # type: ignore -import grpc # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth import credentials as ga_credentials # type: ignore -from google.api_core import exceptions as core_exceptions -from google.api_core import retry as retries -from google.api_core import rest_helpers -from google.api_core import rest_streaming -from google.api_core import path_template -from google.api_core import gapic_v1 - -from google.protobuf import json_format -from requests import __version__ as requests_version -import dataclasses -import re -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union -import warnings - -try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] -except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object] # type: ignore - - -from google.ai.generativelanguage_v1beta2.types import model -from google.ai.generativelanguage_v1beta2.types import model_service - -from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO - - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, - grpc_version=None, - rest_version=requests_version, -) - - -class ModelServiceRestInterceptor: - """Interceptor for ModelService. - - Interceptors are used to manipulate requests, request metadata, and responses - in arbitrary ways. - Example use cases include: - * Logging - * Verifying requests according to service or custom semantics - * Stripping extraneous information from responses - - These use cases and more can be enabled by injecting an - instance of a custom subclass when constructing the ModelServiceRestTransport. - - .. 
code-block:: python - class MyCustomModelServiceInterceptor(ModelServiceRestInterceptor): - def pre_get_model(self, request, metadata): - logging.log(f"Received request: {request}") - return request, metadata - - def post_get_model(self, response): - logging.log(f"Received response: {response}") - return response - - def pre_list_models(self, request, metadata): - logging.log(f"Received request: {request}") - return request, metadata - - def post_list_models(self, response): - logging.log(f"Received response: {response}") - return response - - transport = ModelServiceRestTransport(interceptor=MyCustomModelServiceInterceptor()) - client = ModelServiceClient(transport=transport) - - - """ - def pre_get_model(self, request: model_service.GetModelRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.GetModelRequest, Sequence[Tuple[str, str]]]: - """Pre-rpc interceptor for get_model - - Override in a subclass to manipulate the request or metadata - before they are sent to the ModelService server. - """ - return request, metadata - - def post_get_model(self, response: model.Model) -> model.Model: - """Post-rpc interceptor for get_model - - Override in a subclass to manipulate the response - after it is returned by the ModelService server but before - it is returned to user code. - """ - return response - def pre_list_models(self, request: model_service.ListModelsRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.ListModelsRequest, Sequence[Tuple[str, str]]]: - """Pre-rpc interceptor for list_models - - Override in a subclass to manipulate the request or metadata - before they are sent to the ModelService server. - """ - return request, metadata - - def post_list_models(self, response: model_service.ListModelsResponse) -> model_service.ListModelsResponse: - """Post-rpc interceptor for list_models - - Override in a subclass to manipulate the response - after it is returned by the ModelService server but before - it is returned to user code. - """ - return response - - -@dataclasses.dataclass -class ModelServiceRestStub: - _session: AuthorizedSession - _host: str - _interceptor: ModelServiceRestInterceptor - - -class ModelServiceRestTransport(ModelServiceTransport): - """REST backend transport for ModelService. - - Provides methods for getting metadata information about - Generative Models. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends JSON representations of protocol buffers over HTTP/1.1 - - """ - - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - client_cert_source_for_mtls: Optional[Callable[[ - ], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - url_scheme: str = 'https', - interceptor: Optional[ModelServiceRestInterceptor] = None, - api_audience: Optional[str] = None, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. 
These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client - certificate to configure mutual TLS HTTP channel. It is ignored - if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you are developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - url_scheme: the protocol scheme for the API endpoint. Normally - "https", but for testing or local servers, - "http" can be specified. - """ - # Run the base constructor - # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. - # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the - # credentials object - maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) - if maybe_url_match is None: - raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER - - url_match_items = maybe_url_match.groupdict() - - host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host - - super().__init__( - host=host, - credentials=credentials, - client_info=client_info, - always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience - ) - self._session = AuthorizedSession( - self._credentials, default_host=self.DEFAULT_HOST) - if client_cert_source_for_mtls: - self._session.configure_mtls_channel(client_cert_source_for_mtls) - self._interceptor = interceptor or ModelServiceRestInterceptor() - self._prep_wrapped_messages(client_info) - - class _GetModel(ModelServiceRestStub): - def __hash__(self): - return hash("GetModel") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: model_service.GetModelRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> model.Model: - r"""Call the get model method over HTTP. - - Args: - request (~.model_service.GetModelRequest): - The request object. Request for getting information about - a specific Model. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.model.Model: - Information about a Generative - Language Model. 
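
Tying the REST pieces together, the interceptor hooks defined earlier in this module can be injected when the transport is built; a minimal logging sketch (import paths follow the module shown in this diff):

.. code-block:: python

    import logging

    from google.ai import generativelanguage_v1beta2
    from google.ai.generativelanguage_v1beta2.services.model_service.transports.rest import (
        ModelServiceRestInterceptor,
        ModelServiceRestTransport,
    )

    class LoggingInterceptor(ModelServiceRestInterceptor):
        def pre_get_model(self, request, metadata):
            # Inspect (or rewrite) the request before it is sent.
            logging.info("GetModel request: %s", request)
            return request, metadata

    transport = ModelServiceRestTransport(interceptor=LoggingInterceptor())
    client = generativelanguage_v1beta2.ModelServiceClient(transport=transport)
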
- - """ - - http_options: List[Dict[str, str]] = [{ - 'method': 'get', - 'uri': '/v1beta2/{name=models/*}', - }, - ] - request, metadata = self._interceptor.pre_get_model(request, metadata) - pb_request = model_service.GetModelRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) - - uri = transcoded_request['uri'] - method = transcoded_request['method'] - - # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) - query_params.update(self._get_unset_required_fields(query_params)) - - query_params["$alt"] = "json;enum-encoding=int" - - # Send the request - headers = dict(metadata) - headers['Content-Type'] = 'application/json' - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - ) - - # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception - # subclass. - if response.status_code >= 400: - raise core_exceptions.from_http_response(response) - - # Return the response - resp = model.Model() - pb_resp = model.Model.pb(resp) - - json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) - resp = self._interceptor.post_get_model(resp) - return resp - - class _ListModels(ModelServiceRestStub): - def __hash__(self): - return hash("ListModels") - - def __call__(self, - request: model_service.ListModelsRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> model_service.ListModelsResponse: - r"""Call the list models method over HTTP. - - Args: - request (~.model_service.ListModelsRequest): - The request object. Request for listing all Models. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.model_service.ListModelsResponse: - Response from ``ListModel`` containing a paginated list - of Models. - - """ - - http_options: List[Dict[str, str]] = [{ - 'method': 'get', - 'uri': '/v1beta2/models', - }, - ] - request, metadata = self._interceptor.pre_list_models(request, metadata) - pb_request = model_service.ListModelsRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) - - uri = transcoded_request['uri'] - method = transcoded_request['method'] - - # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) - - query_params["$alt"] = "json;enum-encoding=int" - - # Send the request - headers = dict(metadata) - headers['Content-Type'] = 'application/json' - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - ) - - # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception - # subclass. 
- if response.status_code >= 400: - raise core_exceptions.from_http_response(response) - - # Return the response - resp = model_service.ListModelsResponse() - pb_resp = model_service.ListModelsResponse.pb(resp) - - json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) - resp = self._interceptor.post_list_models(resp) - return resp - - @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - model.Model]: - # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. - # In C++ this would require a dynamic_cast - return self._GetModel(self._session, self._host, self._interceptor) # type: ignore - - @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - model_service.ListModelsResponse]: - # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. - # In C++ this would require a dynamic_cast - return self._ListModels(self._session, self._host, self._interceptor) # type: ignore - - @property - def kind(self) -> str: - return "rest" - - def close(self): - self._session.close() - - -__all__=( - 'ModelServiceRestTransport', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/async_client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/async_client.py deleted file mode 100644 index a063956d2782..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/async_client.py +++ /dev/null @@ -1,514 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from collections import OrderedDict -import functools -import re -from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union - -from google.ai.generativelanguage_v1beta2 import gapic_version as package_version - -from google.api_core.client_options import ClientOptions -from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] -except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object] # type: ignore - -from google.ai.generativelanguage_v1beta2.types import safety -from google.ai.generativelanguage_v1beta2.types import text_service -from .transports.base import TextServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import TextServiceGrpcAsyncIOTransport -from .client import TextServiceClient - - -class TextServiceAsyncClient: - """API for using Generative Language Models (GLMs) trained to - generate text. 
-    Also known as Large Language Models (LLMs), these generate text
-    given an input prompt from the user.
-    """
-
-    _client: TextServiceClient
-
-    DEFAULT_ENDPOINT = TextServiceClient.DEFAULT_ENDPOINT
-    DEFAULT_MTLS_ENDPOINT = TextServiceClient.DEFAULT_MTLS_ENDPOINT
-
-    model_path = staticmethod(TextServiceClient.model_path)
-    parse_model_path = staticmethod(TextServiceClient.parse_model_path)
-    common_billing_account_path = staticmethod(TextServiceClient.common_billing_account_path)
-    parse_common_billing_account_path = staticmethod(TextServiceClient.parse_common_billing_account_path)
-    common_folder_path = staticmethod(TextServiceClient.common_folder_path)
-    parse_common_folder_path = staticmethod(TextServiceClient.parse_common_folder_path)
-    common_organization_path = staticmethod(TextServiceClient.common_organization_path)
-    parse_common_organization_path = staticmethod(TextServiceClient.parse_common_organization_path)
-    common_project_path = staticmethod(TextServiceClient.common_project_path)
-    parse_common_project_path = staticmethod(TextServiceClient.parse_common_project_path)
-    common_location_path = staticmethod(TextServiceClient.common_location_path)
-    parse_common_location_path = staticmethod(TextServiceClient.parse_common_location_path)
-
-    @classmethod
-    def from_service_account_info(cls, info: dict, *args, **kwargs):
-        """Creates an instance of this client using the provided credentials
-            info.
-
-        Args:
-            info (dict): The service account private key info.
-            args: Additional arguments to pass to the constructor.
-            kwargs: Additional arguments to pass to the constructor.
-
-        Returns:
-            TextServiceAsyncClient: The constructed client.
-        """
-        return TextServiceClient.from_service_account_info.__func__(TextServiceAsyncClient, info, *args, **kwargs) # type: ignore
-
-    @classmethod
-    def from_service_account_file(cls, filename: str, *args, **kwargs):
-        """Creates an instance of this client using the provided credentials
-            file.
-
-        Args:
-            filename (str): The path to the service account private key json
-                file.
-            args: Additional arguments to pass to the constructor.
-            kwargs: Additional arguments to pass to the constructor.
-
-        Returns:
-            TextServiceAsyncClient: The constructed client.
-        """
-        return TextServiceClient.from_service_account_file.__func__(TextServiceAsyncClient, filename, *args, **kwargs) # type: ignore
-
-    from_service_account_json = from_service_account_file
-
-    @classmethod
-    def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None):
-        """Return the API endpoint and client cert source for mutual TLS.
-
-        The client cert source is determined in the following order:
-        (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the
-        client cert source is None.
-        (2) if `client_options.client_cert_source` is provided, use the provided one; if the
-        default client cert source exists, use the default one; otherwise the client cert
-        source is None.
-
-        The API endpoint is determined in the following order:
-        (1) if `client_options.api_endpoint` is provided, use the provided one.
-        (2) if `GOOGLE_API_USE_MTLS_ENDPOINT` environment variable is "always", use the
-        default mTLS endpoint; if the environment variable is "never", use the default API
-        endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise
-        use the default API endpoint.
-
-        More details can be found at https://google.aip.dev/auth/4114.
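
In practice the resolution order above reduces to a single call; a sketch (the result also depends on the GOOGLE_API_USE_CLIENT_CERTIFICATE and GOOGLE_API_USE_MTLS_ENDPOINT environment variables):

.. code-block:: python

    from google.api_core.client_options import ClientOptions

    from google.ai import generativelanguage_v1beta2

    options = ClientOptions()  # api_endpoint / client_cert_source may be set here
    endpoint, cert_source = (
        generativelanguage_v1beta2.TextServiceAsyncClient.get_mtls_endpoint_and_cert_source(options)
    )
    print(endpoint)     # the regular or mTLS endpoint, per the rules above
    print(cert_source)  # None unless a client certificate is configured
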
- - Args: - client_options (google.api_core.client_options.ClientOptions): Custom options for the - client. Only the `api_endpoint` and `client_cert_source` properties may be used - in this method. - - Returns: - Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the - client cert source to use. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If any errors happen. - """ - return TextServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore - - @property - def transport(self) -> TextServiceTransport: - """Returns the transport used by the client instance. - - Returns: - TextServiceTransport: The transport used by the client instance. - """ - return self._client.transport - - get_transport_class = functools.partial(type(TextServiceClient).get_transport_class, type(TextServiceClient)) - - def __init__(self, *, - credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, TextServiceTransport] = "grpc_asyncio", - client_options: Optional[ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiates the text service client. - - Args: - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - transport (Union[str, ~.TextServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (ClientOptions): Custom options for the client. It - won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - """ - self._client = TextServiceClient( - credentials=credentials, - transport=transport, - client_options=client_options, - client_info=client_info, - - ) - - async def generate_text(self, - request: Optional[Union[text_service.GenerateTextRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[text_service.TextPrompt] = None, - temperature: Optional[float] = None, - candidate_count: Optional[int] = None, - max_output_tokens: Optional[int] = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> text_service.GenerateTextResponse: - r"""Generates a response from the model given an input - message. - - .. 
code-block:: python
-
-            # This snippet has been automatically generated and should be regarded as a
-            # code template only.
-            # It will require modifications to work:
-            # - It may require correct/in-range values for request initialization.
-            # - It may require specifying regional endpoints when creating the service
-            # client as shown in:
-            # https://googleapis.dev/python/google-api-core/latest/client_options.html
-            from google.ai import generativelanguage_v1beta2
-
-            async def sample_generate_text():
-                # Create a client
-                client = generativelanguage_v1beta2.TextServiceAsyncClient()
-
-                # Initialize request argument(s)
-                prompt = generativelanguage_v1beta2.TextPrompt()
-                prompt.text = "text_value"
-
-                request = generativelanguage_v1beta2.GenerateTextRequest(
-                    model="model_value",
-                    prompt=prompt,
-                )
-
-                # Make the request
-                response = await client.generate_text(request=request)
-
-                # Handle the response
-                print(response)
-
-        Args:
-            request (Optional[Union[google.ai.generativelanguage_v1beta2.types.GenerateTextRequest, dict]]):
-                The request object. Request to generate a text completion
-                response from the model.
-            model (:class:`str`):
-                Required. The model name to use with
-                the format name=models/{model}.
-
-                This corresponds to the ``model`` field
-                on the ``request`` instance; if ``request`` is provided, this
-                should not be set.
-            prompt (:class:`google.ai.generativelanguage_v1beta2.types.TextPrompt`):
-                Required. The free-form input text
-                given to the model as a prompt.
-                Given a prompt, the model will generate
-                a TextCompletion response it predicts as
-                the completion of the input text.
-
-                This corresponds to the ``prompt`` field
-                on the ``request`` instance; if ``request`` is provided, this
-                should not be set.
-            temperature (:class:`float`):
-                Controls the randomness of the output. Note: The default
-                value varies by model, see the ``Model.temperature``
-                attribute of the ``Model`` returned the ``getModel``
-                function.
-
-                Values can range from [0.0,1.0], inclusive. A value
-                closer to 1.0 will produce responses that are more
-                varied and creative, while a value closer to 0.0 will
-                typically result in more straightforward responses from
-                the model.
-
-                This corresponds to the ``temperature`` field
-                on the ``request`` instance; if ``request`` is provided, this
-                should not be set.
-            candidate_count (:class:`int`):
-                Number of generated responses to return.
-
-                This value must be between [1, 8], inclusive. If unset,
-                this will default to 1.
-
-                This corresponds to the ``candidate_count`` field
-                on the ``request`` instance; if ``request`` is provided, this
-                should not be set.
-            max_output_tokens (:class:`int`):
-                The maximum number of tokens to
-                include in a candidate.
-                If unset, this will default to 64.
-
-                This corresponds to the ``max_output_tokens`` field
-                on the ``request`` instance; if ``request`` is provided, this
-                should not be set.
-            top_p (:class:`float`):
-                The maximum cumulative probability of tokens to consider
-                when sampling.
-
-                The model uses combined Top-k and nucleus sampling.
-
-                Tokens are sorted based on their assigned probabilities
-                so that only the most likely tokens are considered.
-                Top-k sampling directly limits the maximum number of
-                tokens to consider, while Nucleus sampling limits the number
-                of tokens based on the cumulative probability.
-
-                Note: The default value varies by model, see the
-                ``Model.top_p`` attribute of the ``Model`` returned the
-                ``getModel`` function.
- - This corresponds to the ``top_p`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - top_k (:class:`int`): - The maximum number of tokens to consider when sampling. - - The model uses combined Top-k and nucleus sampling. - - Top-k sampling considers the set of ``top_k`` most - probable tokens. Defaults to 40. - - Note: The default value varies by model, see the - ``Model.top_k`` attribute of the ``Model`` returned by the - ``getModel`` function. - - This corresponds to the ``top_k`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - google.ai.generativelanguage_v1beta2.types.GenerateTextResponse: - The response from the model, - including candidate completions. - - """ - # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - has_flattened_params = any([model, prompt, temperature, candidate_count, max_output_tokens, top_p, top_k]) - if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") - - request = text_service.GenerateTextRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - if model is not None: - request.model = model - if prompt is not None: - request.prompt = prompt - if temperature is not None: - request.temperature = temperature - if candidate_count is not None: - request.candidate_count = candidate_count - if max_output_tokens is not None: - request.max_output_tokens = max_output_tokens - if top_p is not None: - request.top_p = top_p - if top_k is not None: - request.top_k = top_k - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.generate_text, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), - ) - - # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # Done; return the response. - return response - - async def embed_text(self, - request: Optional[Union[text_service.EmbedTextRequest, dict]] = None, - *, - model: Optional[str] = None, - text: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> text_service.EmbedTextResponse: - r"""Generates an embedding from the model given an input - message. - - .. code-block:: python - - # This snippet has been automatically generated and should be regarded as a - # code template only. - # It will require modifications to work: - # - It may require correct/in-range values for request initialization.
- # - It may require specifying regional endpoints when creating the service - # client as shown in: - # https://googleapis.dev/python/google-api-core/latest/client_options.html - from google.ai import generativelanguage_v1beta2 - - async def sample_embed_text(): - # Create a client - client = generativelanguage_v1beta2.TextServiceAsyncClient() - - # Initialize request argument(s) - request = generativelanguage_v1beta2.EmbedTextRequest( - model="model_value", - text="text_value", - ) - - # Make the request - response = await client.embed_text(request=request) - - # Handle the response - print(response) - - Args: - request (Optional[Union[google.ai.generativelanguage_v1beta2.types.EmbedTextRequest, dict]]): - The request object. Request to get a text embedding from - the model. - model (:class:`str`): - Required. The model name to use with - the format model=models/{model}. - - This corresponds to the ``model`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - text (:class:`str`): - Required. The free-form input text - that the model will turn into an - embedding. - - This corresponds to the ``text`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - google.ai.generativelanguage_v1beta2.types.EmbedTextResponse: - The response to an EmbedTextRequest. - """ - # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - has_flattened_params = any([model, text]) - if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") - - request = text_service.EmbedTextRequest(request) - - # If we have keyword arguments corresponding to fields on the - # request, apply these. - if model is not None: - request.model = model - if text is not None: - request.text = text - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = gapic_v1.method_async.wrap_method( - self._client._transport.embed_text, - default_timeout=None, - client_info=DEFAULT_CLIENT_INFO, - ) - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), - ) - - # Send the request. - response = await rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # Done; return the response. - return response - - async def __aenter__(self) -> "TextServiceAsyncClient": - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.transport.close() - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - - -__all__ = ( - "TextServiceAsyncClient", -)
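For review context, here is a minimal usage sketch of the async client deleted above (not part of the patch). It assumes application default credentials and a reachable v1beta2 backend; the model name is a placeholder, not a guaranteed identifier.

.. code-block:: python

    import asyncio

    from google.ai import generativelanguage_v1beta2


    async def main():
        client = generativelanguage_v1beta2.TextServiceAsyncClient()

        # Flattened keyword arguments are folded into a GenerateTextRequest;
        # mixing them with an explicit `request` object raises ValueError.
        response = await client.generate_text(
            model="models/text-bison-001",  # placeholder model name
            prompt=generativelanguage_v1beta2.TextPrompt(text="Hello"),
            temperature=0.5,
            candidate_count=2,
        )
        print(response)


    asyncio.run(main())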
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/client.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/client.py deleted file mode 100644 index 39ecd7327b22..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/client.py +++ /dev/null @@ -1,718 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from collections import OrderedDict -import os -import re -from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast - -from google.ai.generativelanguage_v1beta2 import gapic_version as package_version - -from google.api_core import client_options as client_options_lib -from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore - -try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] -except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object] # type: ignore - -from google.ai.generativelanguage_v1beta2.types import safety -from google.ai.generativelanguage_v1beta2.types import text_service -from .transports.base import TextServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc import TextServiceGrpcTransport -from .transports.grpc_asyncio import TextServiceGrpcAsyncIOTransport -from .transports.rest import TextServiceRestTransport - - -class TextServiceClientMeta(type): - """Metaclass for the TextService client. - - This provides class-level methods for building and retrieving - support objects (e.g. transport) without polluting the client instance - objects. - """ - _transport_registry = OrderedDict() # type: Dict[str, Type[TextServiceTransport]] - _transport_registry["grpc"] = TextServiceGrpcTransport - _transport_registry["grpc_asyncio"] = TextServiceGrpcAsyncIOTransport - _transport_registry["rest"] = TextServiceRestTransport - - def get_transport_class(cls, - label: Optional[str] = None, - ) -> Type[TextServiceTransport]: - """Returns an appropriate transport class. - - Args: - label: The name of the desired transport. If none is - provided, then the first transport in the registry is used.
- - Returns: - The transport class to use. - """ - # If a specific transport is requested, return that one. - if label: - return cls._transport_registry[label] - - # No transport is requested; return the default (that is, the first one - # in the dictionary). - return next(iter(cls._transport_registry.values())) - - -class TextServiceClient(metaclass=TextServiceClientMeta): - """API for using Generative Language Models (GLMs) trained to - generate text. - Also known as Large Language Models (LLMs), these generate text - given an input prompt from the user. - """ - - @staticmethod - def _get_default_mtls_endpoint(api_endpoint): - """Converts api endpoint to mTLS endpoint. - - Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to - "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. - Args: - api_endpoint (Optional[str]): the api endpoint to convert. - Returns: - str: converted mTLS api endpoint. - """ - if not api_endpoint: - return api_endpoint - - mtls_endpoint_re = re.compile( - r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?" - ) - - m = mtls_endpoint_re.match(api_endpoint) - name, mtls, sandbox, googledomain = m.groups() - if mtls or not googledomain: - return api_endpoint - - if sandbox: - return api_endpoint.replace( - "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" - ) - - return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") - - DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" - DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore - DEFAULT_ENDPOINT - ) - - @classmethod - def from_service_account_info(cls, info: dict, *args, **kwargs): - """Creates an instance of this client using the provided credentials - info. - - Args: - info (dict): The service account private key info. - args: Additional arguments to pass to the constructor. - kwargs: Additional arguments to pass to the constructor. - - Returns: - TextServiceClient: The constructed client. - """ - credentials = service_account.Credentials.from_service_account_info(info) - kwargs["credentials"] = credentials - return cls(*args, **kwargs) - - @classmethod - def from_service_account_file(cls, filename: str, *args, **kwargs): - """Creates an instance of this client using the provided credentials - file. - - Args: - filename (str): The path to the service account private key json - file. - args: Additional arguments to pass to the constructor. - kwargs: Additional arguments to pass to the constructor. - - Returns: - TextServiceClient: The constructed client. - """ - credentials = service_account.Credentials.from_service_account_file( - filename) - kwargs["credentials"] = credentials - return cls(*args, **kwargs) - - from_service_account_json = from_service_account_file - - @property - def transport(self) -> TextServiceTransport: - """Returns the transport used by the client instance. - - Returns: - TextServiceTransport: The transport used by the client - instance.
- """ - return self._transport - - @staticmethod - def model_path(model: str,) -> str: - """Returns a fully-qualified model string.""" - return "models/{model}".format(model=model, ) - - @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: - """Parses a model path into its component segments.""" - m = re.match(r"^models/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: - """Returns a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) - - @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: - """Parse a billing_account path into its component segments.""" - m = re.match(r"^billingAccounts/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_folder_path(folder: str, ) -> str: - """Returns a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) - - @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: - """Parse a folder path into its component segments.""" - m = re.match(r"^folders/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_organization_path(organization: str, ) -> str: - """Returns a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) - - @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: - """Parse a organization path into its component segments.""" - m = re.match(r"^organizations/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_project_path(project: str, ) -> str: - """Returns a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) - - @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: - """Parse a project path into its component segments.""" - m = re.match(r"^projects/(?P.+?)$", path) - return m.groupdict() if m else {} - - @staticmethod - def common_location_path(project: str, location: str, ) -> str: - """Returns a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) - - @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: - """Parse a location path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) - return m.groupdict() if m else {} - - @classmethod - def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None): - """Return the API endpoint and client cert source for mutual TLS. - - The client cert source is determined in the following order: - (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the - client cert source is None. - (2) if `client_options.client_cert_source` is provided, use the provided one; if the - default client cert source exists, use the default one; otherwise the client cert - source is None. - - The API endpoint is determined in the following order: - (1) if `client_options.api_endpoint` if provided, use the provided one. - (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the - default mTLS endpoint; if the environment variable is "never", use the default API - endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise - use the default API endpoint. 
- - More details can be found at https://google.aip.dev/auth/4114. - - Args: - client_options (google.api_core.client_options.ClientOptions): Custom options for the - client. Only the `api_endpoint` and `client_cert_source` properties may be used - in this method. - - Returns: - Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the - client cert source to use. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If any errors happen. - """ - if client_options is None: - client_options = client_options_lib.ClientOptions() - use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") - use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") - if use_client_cert not in ("true", "false"): - raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`") - if use_mtls_endpoint not in ("auto", "never", "always"): - raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`") - - # Figure out the client cert source to use. - client_cert_source = None - if use_client_cert == "true": - if client_options.client_cert_source: - client_cert_source = client_options.client_cert_source - elif mtls.has_default_client_cert_source(): - client_cert_source = mtls.default_client_cert_source() - - # Figure out which api endpoint to use. - if client_options.api_endpoint is not None: - api_endpoint = client_options.api_endpoint - elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source): - api_endpoint = cls.DEFAULT_MTLS_ENDPOINT - else: - api_endpoint = cls.DEFAULT_ENDPOINT - - return api_endpoint, client_cert_source - - def __init__(self, *, - credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, TextServiceTransport]] = None, - client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: - """Instantiates the text service client. - - Args: - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - transport (Union[str, TextServiceTransport]): The - transport to use. If set to None, a transport is chosen - automatically. - client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): Custom options for the - client. It won't take effect if a ``transport`` instance is provided. - (1) The ``api_endpoint`` property can be used to override the - default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT - environment variable can also be used to override the endpoint: - "always" (always use the default mTLS endpoint), "never" (always - use the default regular endpoint) and "auto" (auto switch to the - default mTLS endpoint if client certificate is present, this is - the default value). However, the ``api_endpoint`` property takes - precedence if provided. - (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable - is "true", then the ``client_cert_source`` property can be used - to provide client certificate for mutual TLS transport. If - not provided, the default SSL client certificate will be used if - present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not - set, no client certificate will be used. 
- client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - """ - if isinstance(client_options, dict): - client_options = client_options_lib.from_dict(client_options) - if client_options is None: - client_options = client_options_lib.ClientOptions() - client_options = cast(client_options_lib.ClientOptions, client_options) - - api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) - - api_key_value = getattr(client_options, "api_key", None) - if api_key_value and credentials: - raise ValueError("client_options.api_key and credentials are mutually exclusive") - - # Save or instantiate the transport. - # Ordinarily, we provide the transport, but allowing a custom transport - # instance provides an extensibility point for unusual situations. - if isinstance(transport, TextServiceTransport): - # transport is a TextServiceTransport instance. - if credentials or client_options.credentials_file or api_key_value: - raise ValueError("When providing a transport instance, " - "provide its credentials directly.") - if client_options.scopes: - raise ValueError( - "When providing a transport instance, provide its scopes " - "directly." - ) - self._transport = transport - else: - import google.auth._default # type: ignore - - if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): - credentials = google.auth._default.get_api_key_credentials(api_key_value) - - Transport = type(self).get_transport_class(transport) - self._transport = Transport( - credentials=credentials, - credentials_file=client_options.credentials_file, - host=api_endpoint, - scopes=client_options.scopes, - client_cert_source_for_mtls=client_cert_source_func, - quota_project_id=client_options.quota_project_id, - client_info=client_info, - always_use_jwt_access=True, - api_audience=client_options.api_audience, - ) - - def generate_text(self, - request: Optional[Union[text_service.GenerateTextRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[text_service.TextPrompt] = None, - temperature: Optional[float] = None, - candidate_count: Optional[int] = None, - max_output_tokens: Optional[int] = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> text_service.GenerateTextResponse: - r"""Generates a response from the model given an input - message. - - .. code-block:: python - - # This snippet has been automatically generated and should be regarded as a - # code template only. - # It will require modifications to work: - # - It may require correct/in-range values for request initialization. 
- # - It may require specifying regional endpoints when creating the service - # client as shown in: - # https://googleapis.dev/python/google-api-core/latest/client_options.html - from google.ai import generativelanguage_v1beta2 - - def sample_generate_text(): - # Create a client - client = generativelanguage_v1beta2.TextServiceClient() - - # Initialize request argument(s) - prompt = generativelanguage_v1beta2.TextPrompt() - prompt.text = "text_value" - - request = generativelanguage_v1beta2.GenerateTextRequest( - model="model_value", - prompt=prompt, - ) - - # Make the request - response = client.generate_text(request=request) - - # Handle the response - print(response) - - Args: - request (Union[google.ai.generativelanguage_v1beta2.types.GenerateTextRequest, dict]): - The request object. Request to generate a text completion - response from the model. - model (str): - Required. The model name to use with - the format name=models/{model}. - - This corresponds to the ``model`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - prompt (google.ai.generativelanguage_v1beta2.types.TextPrompt): - Required. The free-form input text - given to the model as a prompt. - Given a prompt, the model will generate - a TextCompletion response it predicts as - the completion of the input text. - - This corresponds to the ``prompt`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - temperature (float): - Controls the randomness of the output. Note: The default - value varies by model, see the ``Model.temperature`` - attribute of the ``Model`` returned by the ``getModel`` - function. - - Values can range from [0.0,1.0], inclusive. A value - closer to 1.0 will produce responses that are more - varied and creative, while a value closer to 0.0 will - typically result in more straightforward responses from - the model. - - This corresponds to the ``temperature`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - candidate_count (int): - Number of generated responses to return. - - This value must be between [1, 8], inclusive. If unset, - this will default to 1. - - This corresponds to the ``candidate_count`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - max_output_tokens (int): - The maximum number of tokens to - include in a candidate. - If unset, this will default to 64. - - This corresponds to the ``max_output_tokens`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - top_p (float): - The maximum cumulative probability of tokens to consider - when sampling. - - The model uses combined Top-k and nucleus sampling. - - Tokens are sorted based on their assigned probabilities - so that only the most likely tokens are considered. - Top-k sampling directly limits the maximum number of - tokens to consider, while nucleus sampling limits the number - of tokens based on the cumulative probability. - - Note: The default value varies by model, see the - ``Model.top_p`` attribute of the ``Model`` returned by the - ``getModel`` function. - - This corresponds to the ``top_p`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - top_k (int): - The maximum number of tokens to consider when sampling. - - The model uses combined Top-k and nucleus sampling. - - Top-k sampling considers the set of ``top_k`` most - probable tokens. Defaults to 40.
- - Note: The default value varies by model, see the - ``Model.top_k`` attribute of the ``Model`` returned by the - ``getModel`` function. - - This corresponds to the ``top_k`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - google.ai.generativelanguage_v1beta2.types.GenerateTextResponse: - The response from the model, - including candidate completions. - - """ - # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - has_flattened_params = any([model, prompt, temperature, candidate_count, max_output_tokens, top_p, top_k]) - if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') - - # Minor optimization to avoid making a copy if the user passes - # in a text_service.GenerateTextRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, text_service.GenerateTextRequest): - request = text_service.GenerateTextRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. - if model is not None: - request.model = model - if prompt is not None: - request.prompt = prompt - if temperature is not None: - request.temperature = temperature - if candidate_count is not None: - request.candidate_count = candidate_count - if max_output_tokens is not None: - request.max_output_tokens = max_output_tokens - if top_p is not None: - request.top_p = top_p - if top_k is not None: - request.top_k = top_k - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.generate_text] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), - ) - - # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # Done; return the response. - return response - - def embed_text(self, - request: Optional[Union[text_service.EmbedTextRequest, dict]] = None, - *, - model: Optional[str] = None, - text: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> text_service.EmbedTextResponse: - r"""Generates an embedding from the model given an input - message. - - .. code-block:: python - - # This snippet has been automatically generated and should be regarded as a - # code template only. - # It will require modifications to work: - # - It may require correct/in-range values for request initialization.
- # - It may require specifying regional endpoints when creating the service - # client as shown in: - # https://googleapis.dev/python/google-api-core/latest/client_options.html - from google.ai import generativelanguage_v1beta2 - - def sample_embed_text(): - # Create a client - client = generativelanguage_v1beta2.TextServiceClient() - - # Initialize request argument(s) - request = generativelanguage_v1beta2.EmbedTextRequest( - model="model_value", - text="text_value", - ) - - # Make the request - response = client.embed_text(request=request) - - # Handle the response - print(response) - - Args: - request (Union[google.ai.generativelanguage_v1beta2.types.EmbedTextRequest, dict]): - The request object. Request to get a text embedding from - the model. - model (str): - Required. The model name to use with - the format model=models/{model}. - - This corresponds to the ``model`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - text (str): - Required. The free-form input text - that the model will turn into an - embedding. - - This corresponds to the ``text`` field - on the ``request`` instance; if ``request`` is provided, this - should not be set. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - google.ai.generativelanguage_v1beta2.types.EmbedTextResponse: - The response to an EmbedTextRequest. - """ - # Create or coerce a protobuf request object. - # Quick check: If we got a request object, we should *not* have - # gotten any keyword arguments that map to the request. - has_flattened_params = any([model, text]) - if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') - - # Minor optimization to avoid making a copy if the user passes - # in a text_service.EmbedTextRequest. - # There's no risk of modifying the input as we've already verified - # there are no flattened fields. - if not isinstance(request, text_service.EmbedTextRequest): - request = text_service.EmbedTextRequest(request) - # If we have keyword arguments corresponding to fields on the - # request, apply these. - if model is not None: - request.model = model - if text is not None: - request.text = text - - # Wrap the RPC method; this adds retry and timeout information, - # and friendly error handling. - rpc = self._transport._wrapped_methods[self._transport.embed_text] - - # Certain fields should be provided within the metadata header; - # add these here. - metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), - ) - - # Send the request. - response = rpc( - request, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - # Done; return the response. - return response - - def __enter__(self) -> "TextServiceClient": - return self - - def __exit__(self, type, value, traceback): - """Releases underlying transport's resources. - - .. warning:: - ONLY use as a context manager if the transport is NOT shared - with other clients! Exiting the with block will CLOSE the transport - and may cause errors in other clients! - """ - self.transport.close() - - - - - - - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - - -__all__ = ( - "TextServiceClient", -)
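As context for the synchronous client deleted above, a minimal sketch (not part of the patch) of the endpoint-resolution and context-manager behavior its docstrings describe; it assumes default credentials, and the model name is a placeholder:

.. code-block:: python

    from google.api_core.client_options import ClientOptions

    from google.ai import generativelanguage_v1beta2

    # An explicit api_endpoint always wins over the GOOGLE_API_USE_MTLS_ENDPOINT
    # environment variable (see get_mtls_endpoint_and_cert_source above).
    options = ClientOptions(api_endpoint="generativelanguage.googleapis.com")

    # Safe as a context manager only if the transport is not shared;
    # leaving the block closes the transport.
    with generativelanguage_v1beta2.TextServiceClient(client_options=options) as client:
        request = generativelanguage_v1beta2.GenerateTextRequest(
            model="models/text-bison-001",  # placeholder model name
            prompt=generativelanguage_v1beta2.TextPrompt(text="Say hi"),
        )
        print(client.generate_text(request=request))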
- """ - self.transport.close() - - - - - - - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - - -__all__ = ( - "TextServiceClient", -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/base.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/base.py deleted file mode 100644 index b038dec99299..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/base.py +++ /dev/null @@ -1,161 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import abc -from typing import Awaitable, Callable, Dict, Optional, Sequence, Union - -from google.ai.generativelanguage_v1beta2 import gapic_version as package_version - -import google.auth # type: ignore -import google.api_core -from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.oauth2 import service_account # type: ignore - -from google.ai.generativelanguage_v1beta2.types import text_service - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - - -class TextServiceTransport(abc.ABC): - """Abstract transport class for TextService.""" - - AUTH_SCOPES = ( - ) - - DEFAULT_HOST: str = 'generativelanguage.googleapis.com' - def __init__( - self, *, - host: str = DEFAULT_HOST, - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - **kwargs, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): A list of scopes. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. 
- Generally, you only need to set this if you're developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - """ - - scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} - - # Save the scopes. - self._scopes = scopes - - # If no credentials are provided, then determine the appropriate - # defaults. - if credentials and credentials_file: - raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") - - if credentials_file is not None: - credentials, _ = google.auth.load_credentials_from_file( - credentials_file, - **scopes_kwargs, - quota_project_id=quota_project_id - ) - elif credentials is None: - credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id) - # Don't apply audience if a credentials file was passed from the user. - if hasattr(credentials, "with_gdch_audience"): - credentials = credentials.with_gdch_audience(api_audience if api_audience else host) - - # If the credentials are service account credentials, then always try to use self signed JWT. - if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"): - credentials = credentials.with_always_use_jwt_access(True) - - # Save the credentials. - self._credentials = credentials - - # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' - self._host = host - - def _prep_wrapped_messages(self, client_info): - # Precompute the wrapped methods. - self._wrapped_methods = { - self.generate_text: gapic_v1.method.wrap_method( - self.generate_text, - default_timeout=None, - client_info=client_info, - ), - self.embed_text: gapic_v1.method.wrap_method( - self.embed_text, - default_timeout=None, - client_info=client_info, - ), - } - - def close(self): - """Closes resources associated with the transport. - - .. warning:: - Only call this method if the transport is NOT shared - with other clients - this may cause errors in other clients! - """ - raise NotImplementedError() - - @property - def generate_text(self) -> Callable[ - [text_service.GenerateTextRequest], - Union[ - text_service.GenerateTextResponse, - Awaitable[text_service.GenerateTextResponse] - ]]: - raise NotImplementedError() - - @property - def embed_text(self) -> Callable[ - [text_service.EmbedTextRequest], - Union[ - text_service.EmbedTextResponse, - Awaitable[text_service.EmbedTextResponse] - ]]: - raise NotImplementedError() - - @property - def kind(self) -> str: - raise NotImplementedError() - - -__all__ = ( - 'TextServiceTransport', -)
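The base transport above resolves credentials in a fixed order: explicit credentials, then a credentials file, then application default credentials; passing both of the first two raises DuplicateCredentialArgs. A minimal sketch (not part of the patch; the file path is hypothetical):

.. code-block:: python

    from google.ai import generativelanguage_v1beta2

    # Explicit service-account credentials via the convenience constructor...
    client = generativelanguage_v1beta2.TextServiceClient.from_service_account_file(
        "service-account.json"  # hypothetical path
    )

    # ...or application default credentials, resolved by the transport base
    # class via google.auth.default() when neither `credentials` nor
    # `credentials_file` is supplied.
    client = generativelanguage_v1beta2.TextServiceClient()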
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc.py deleted file mode 100644 index 4835582937e6..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc.py +++ /dev/null @@ -1,295 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import warnings -from typing import Callable, Dict, Optional, Sequence, Tuple, Union - -from google.api_core import grpc_helpers -from google.api_core import gapic_v1 -import google.auth # type: ignore -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore - -from google.ai.generativelanguage_v1beta2.types import text_service -from .base import TextServiceTransport, DEFAULT_CLIENT_INFO - - -class TextServiceGrpcTransport(TextServiceTransport): - """gRPC backend transport for TextService. - - API for using Generative Language Models (GLMs) trained to - generate text. - Also known as Large Language Models (LLMs), these generate text - given an input prompt from the user. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. - """ - _stubs: Dict[str, Callable] - - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): A list of scopes. This argument is - ignored if ``channel`` is provided. - channel (Optional[grpc.Channel]): A ``Channel`` instance through - which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format.
It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. - client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): - A callback to provide client certificate bytes and private key bytes, - both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - - Raises: - google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - self._grpc_channel = None - self._ssl_channel_credentials = ssl_channel_credentials - self._stubs: Dict[str, Callable] = {} - - if api_mtls_endpoint: - warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) - if client_cert_source: - warnings.warn("client_cert_source is deprecated", DeprecationWarning) - - if channel: - # Ignore credentials if a channel was passed. - credentials = False - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - self._ssl_channel_credentials = None - - else: - if api_mtls_endpoint: - host = api_mtls_endpoint - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - self._ssl_channel_credentials = SslCredentials().ssl_credentials - - else: - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - - # The base transport sets the host, credentials and scopes - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - client_info=client_info, - always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience, - ) - - if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( - self._host, - # use the credentials which are saved - credentials=self._credentials, - # Set ``credentials_file`` to ``None`` here as - # the credentials that we saved earlier should be used. - credentials_file=None, - scopes=self._scopes, - ssl_credentials=self._ssl_channel_credentials, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - # Wrap messages. 
This must be done after self._grpc_channel exists - self._prep_wrapped_messages(client_info) - - @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: - """Create and return a gRPC channel object. - Args: - host (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): An optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - grpc.Channel: A gRPC channel object. - - Raises: - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - - return grpc_helpers.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - quota_project_id=quota_project_id, - default_scopes=cls.AUTH_SCOPES, - scopes=scopes, - default_host=cls.DEFAULT_HOST, - **kwargs - ) - - @property - def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ - return self._grpc_channel - - @property - def generate_text(self) -> Callable[ - [text_service.GenerateTextRequest], - text_service.GenerateTextResponse]: - r"""Return a callable for the generate text method over gRPC. - - Generates a response from the model given an input - message. - - Returns: - Callable[[~.GenerateTextRequest], - ~.GenerateTextResponse]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if 'generate_text' not in self._stubs: - self._stubs['generate_text'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta2.TextService/GenerateText', - request_serializer=text_service.GenerateTextRequest.serialize, - response_deserializer=text_service.GenerateTextResponse.deserialize, - ) - return self._stubs['generate_text'] - - @property - def embed_text(self) -> Callable[ - [text_service.EmbedTextRequest], - text_service.EmbedTextResponse]: - r"""Return a callable for the embed text method over gRPC. - - Generates an embedding from the model given an input - message. - - Returns: - Callable[[~.EmbedTextRequest], - ~.EmbedTextResponse]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if 'embed_text' not in self._stubs: - self._stubs['embed_text'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta2.TextService/EmbedText', - request_serializer=text_service.EmbedTextRequest.serialize, - response_deserializer=text_service.EmbedTextResponse.deserialize, - ) - return self._stubs['embed_text'] - - def close(self): - self.grpc_channel.close() - - @property - def kind(self) -> str: - return "grpc" - - -__all__ = ( - 'TextServiceGrpcTransport', -)
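Per the transport docstring above, a caller-supplied channel bypasses credential and mTLS handling entirely; the transport stores the channel as-is and builds its stubs on it. A minimal sketch (not part of the patch), assuming the transports subpackage is importable as shown:

.. code-block:: python

    import grpc

    from google.ai.generativelanguage_v1beta2.services.text_service import (
        TextServiceClient,
        transports,
    )

    # When `channel` is provided, credentials and credentials_file are ignored.
    channel = grpc.secure_channel(
        "generativelanguage.googleapis.com:443", grpc.ssl_channel_credentials()
    )
    transport = transports.TextServiceGrpcTransport(channel=channel)
    client = TextServiceClient(transport=transport)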
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc_asyncio.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc_asyncio.py deleted file mode 100644 index 8a8cdeeda949..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/grpc_asyncio.py +++ /dev/null @@ -1,294 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import warnings -from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union - -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers_async -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore -from grpc.experimental import aio # type: ignore - -from google.ai.generativelanguage_v1beta2.types import text_service -from .base import TextServiceTransport, DEFAULT_CLIENT_INFO -from .grpc import TextServiceGrpcTransport - - -class TextServiceGrpcAsyncIOTransport(TextServiceTransport): - """gRPC AsyncIO backend transport for TextService. - - API for using Generative Language Models (GLMs) trained to - generate text. - Also known as Large Language Models (LLMs), these generate text - given an input prompt from the user. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends protocol buffers over the wire using gRPC (which is built on - top of HTTP/2); the ``grpcio`` package must be installed. - """ - - _grpc_channel: aio.Channel - _stubs: Dict[str, Callable] = {} - - @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: - """Create and return a gRPC AsyncIO channel object. - Args: - host (Optional[str]): The host for the channel to use. - credentials (Optional[~.Credentials]): The - authorization credentials to attach to requests. These - credentials identify this application to the service. If - none are specified, the client will attempt to ascertain - the credentials from the environment.
- credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is mutually exclusive with credentials. - scopes (Optional[Sequence[str]]): An optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - kwargs (Optional[dict]): Keyword arguments, which are passed to the - channel creation. - Returns: - aio.Channel: A gRPC AsyncIO channel object. - """ - - return grpc_helpers_async.create_channel( - host, - credentials=credentials, - credentials_file=credentials_file, - quota_project_id=quota_project_id, - default_scopes=cls.AUTH_SCOPES, - scopes=scopes, - default_host=cls.DEFAULT_HOST, - **kwargs - ) - - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - This argument is ignored if ``channel`` is provided. - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional[Sequence[str]]): An optional list of scopes needed for this - service. These are only used when credentials are not specified and - are passed to :func:`google.auth.default`. - channel (Optional[aio.Channel]): A ``Channel`` instance through - which to make calls. - api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. - If provided, it overrides the ``host`` argument and tries to create - a mutual TLS channel with client SSL credentials from - ``client_cert_source`` or application default SSL credentials. - client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): - Deprecated. A callback to provide client SSL certificate bytes and - private key bytes, both in PEM format. It is ignored if - ``api_mtls_endpoint`` is None. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for the grpc channel. It is ignored if ``channel`` is provided. - client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): - A callback to provide client certificate bytes and private key bytes, - both in PEM format. It is used to configure a mutual TLS channel. It is - ignored if ``channel`` or ``ssl_channel_credentials`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota.
- client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - - Raises: - google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport - creation failed for any reason. - google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` - and ``credentials_file`` are passed. - """ - self._grpc_channel = None - self._ssl_channel_credentials = ssl_channel_credentials - self._stubs: Dict[str, Callable] = {} - - if api_mtls_endpoint: - warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) - if client_cert_source: - warnings.warn("client_cert_source is deprecated", DeprecationWarning) - - if channel: - # Ignore credentials if a channel was passed. - credentials = False - # If a channel was explicitly provided, set it. - self._grpc_channel = channel - self._ssl_channel_credentials = None - else: - if api_mtls_endpoint: - host = api_mtls_endpoint - - # Create SSL credentials with client_cert_source or application - # default SSL credentials. - if client_cert_source: - cert, key = client_cert_source() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - else: - self._ssl_channel_credentials = SslCredentials().ssl_credentials - - else: - if client_cert_source_for_mtls and not ssl_channel_credentials: - cert, key = client_cert_source_for_mtls() - self._ssl_channel_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) - - # The base transport sets the host, credentials and scopes - super().__init__( - host=host, - credentials=credentials, - credentials_file=credentials_file, - scopes=scopes, - quota_project_id=quota_project_id, - client_info=client_info, - always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience, - ) - - if not self._grpc_channel: - self._grpc_channel = type(self).create_channel( - self._host, - # use the credentials which are saved - credentials=self._credentials, - # Set ``credentials_file`` to ``None`` here as - # the credentials that we saved earlier should be used. - credentials_file=None, - scopes=self._scopes, - ssl_credentials=self._ssl_channel_credentials, - quota_project_id=quota_project_id, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - # Wrap messages. This must be done after self._grpc_channel exists - self._prep_wrapped_messages(client_info) - - @property - def grpc_channel(self) -> aio.Channel: - """Create the channel designed to connect to this service. - - This property caches on the instance; repeated calls return - the same channel. - """ - # Return the channel from cache. - return self._grpc_channel - - @property - def generate_text(self) -> Callable[ - [text_service.GenerateTextRequest], - Awaitable[text_service.GenerateTextResponse]]: - r"""Return a callable for the generate text method over gRPC. - - Generates a response from the model given an input - message. - - Returns: - Callable[[~.GenerateTextRequest], - Awaitable[~.GenerateTextResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. 
- # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if 'generate_text' not in self._stubs: - self._stubs['generate_text'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta2.TextService/GenerateText', - request_serializer=text_service.GenerateTextRequest.serialize, - response_deserializer=text_service.GenerateTextResponse.deserialize, - ) - return self._stubs['generate_text'] - - @property - def embed_text(self) -> Callable[ - [text_service.EmbedTextRequest], - Awaitable[text_service.EmbedTextResponse]]: - r"""Return a callable for the embed text method over gRPC. - - Generates an embedding from the model given an input - message. - - Returns: - Callable[[~.EmbedTextRequest], - Awaitable[~.EmbedTextResponse]]: - A function that, when called, will call the underlying RPC - on the server. - """ - # Generate a "stub function" on-the-fly which will actually make - # the request. - # gRPC handles serialization and deserialization, so we just need - # to pass in the functions for each. - if 'embed_text' not in self._stubs: - self._stubs['embed_text'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta2.TextService/EmbedText', - request_serializer=text_service.EmbedTextRequest.serialize, - response_deserializer=text_service.EmbedTextResponse.deserialize, - ) - return self._stubs['embed_text'] - - def close(self): - return self.grpc_channel.close() - - -__all__ = ( - 'TextServiceGrpcAsyncIOTransport', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/rest.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/rest.py deleted file mode 100644 index 2480e0dd0389..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/rest.py +++ /dev/null @@ -1,423 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
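The AsyncIO transport above is normally reached through the generated ``TextServiceAsyncClient`` rather than constructed by hand. A minimal usage sketch, assuming application-default credentials are available and using a placeholder model name:

    import asyncio

    from google.ai import generativelanguage_v1beta2 as glm


    async def main():
        # The async client uses TextServiceGrpcAsyncIOTransport by default.
        client = glm.TextServiceAsyncClient()
        response = await client.generate_text(
            request=glm.GenerateTextRequest(
                model="models/text-bison-001",  # placeholder model name
                prompt=glm.TextPrompt(text="Write a two-line poem about the sea."),
            )
        )
        for candidate in response.candidates:
            print(candidate.output)


    asyncio.run(main())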
-# - -from google.auth.transport.requests import AuthorizedSession # type: ignore -import json # type: ignore -import grpc # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth import credentials as ga_credentials # type: ignore -from google.api_core import exceptions as core_exceptions -from google.api_core import retry as retries -from google.api_core import rest_helpers -from google.api_core import rest_streaming -from google.api_core import path_template -from google.api_core import gapic_v1 - -from google.protobuf import json_format -from requests import __version__ as requests_version -import dataclasses -import re -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union -import warnings - -try: - OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] -except AttributeError: # pragma: NO COVER - OptionalRetry = Union[retries.Retry, object] # type: ignore - - -from google.ai.generativelanguage_v1beta2.types import text_service - -from .base import TextServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO - - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( - gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, - grpc_version=None, - rest_version=requests_version, -) - - -class TextServiceRestInterceptor: - """Interceptor for TextService. - - Interceptors are used to manipulate requests, request metadata, and responses - in arbitrary ways. - Example use cases include: - * Logging - * Verifying requests according to service or custom semantics - * Stripping extraneous information from responses - - These use cases and more can be enabled by injecting an - instance of a custom subclass when constructing the TextServiceRestTransport. - - .. code-block:: python - class MyCustomTextServiceInterceptor(TextServiceRestInterceptor): - def pre_embed_text(self, request, metadata): - logging.info(f"Received request: {request}") - return request, metadata - - def post_embed_text(self, response): - logging.info(f"Received response: {response}") - return response - - def pre_generate_text(self, request, metadata): - logging.info(f"Received request: {request}") - return request, metadata - - def post_generate_text(self, response): - logging.info(f"Received response: {response}") - return response - - transport = TextServiceRestTransport(interceptor=MyCustomTextServiceInterceptor()) - client = TextServiceClient(transport=transport) - - - """ - def pre_embed_text(self, request: text_service.EmbedTextRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[text_service.EmbedTextRequest, Sequence[Tuple[str, str]]]: - """Pre-rpc interceptor for embed_text - - Override in a subclass to manipulate the request or metadata - before they are sent to the TextService server. - """ - return request, metadata - - def post_embed_text(self, response: text_service.EmbedTextResponse) -> text_service.EmbedTextResponse: - """Post-rpc interceptor for embed_text - - Override in a subclass to manipulate the response - after it is returned by the TextService server but before - it is returned to user code. - """ - return response - def pre_generate_text(self, request: text_service.GenerateTextRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[text_service.GenerateTextRequest, Sequence[Tuple[str, str]]]: - """Pre-rpc interceptor for generate_text - - Override in a subclass to manipulate the request or metadata - before they are sent to the TextService server.
- """ - return request, metadata - - def post_generate_text(self, response: text_service.GenerateTextResponse) -> text_service.GenerateTextResponse: - """Post-rpc interceptor for generate_text - - Override in a subclass to manipulate the response - after it is returned by the TextService server but before - it is returned to user code. - """ - return response - - -@dataclasses.dataclass -class TextServiceRestStub: - _session: AuthorizedSession - _host: str - _interceptor: TextServiceRestInterceptor - - -class TextServiceRestTransport(TextServiceTransport): - """REST backend transport for TextService. - - API for using Generative Language Models (GLMs) trained to - generate text. - Also known as Large Language Models (LLM)s, these generate text - given an input prompt from the user. - - This class defines the same methods as the primary client, so the - primary client can load the underlying transport implementation - and call it. - - It sends JSON representations of protocol buffers over HTTP/1.1 - - """ - - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - client_cert_source_for_mtls: Optional[Callable[[ - ], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - url_scheme: str = 'https', - interceptor: Optional[TextServiceRestInterceptor] = None, - api_audience: Optional[str] = None, - ) -> None: - """Instantiate the transport. - - Args: - host (Optional[str]): - The hostname to connect to. - credentials (Optional[google.auth.credentials.Credentials]): The - authorization credentials to attach to requests. These - credentials identify the application to the service; if none - are specified, the client will attempt to ascertain the - credentials from the environment. - - credentials_file (Optional[str]): A file with credentials that can - be loaded with :func:`google.auth.load_credentials_from_file`. - This argument is ignored if ``channel`` is provided. - scopes (Optional(Sequence[str])): A list of scopes. This argument is - ignored if ``channel`` is provided. - client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client - certificate to configure mutual TLS HTTP channel. It is ignored - if ``channel`` is provided. - quota_project_id (Optional[str]): An optional project to use for billing - and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you are developing - your own client library. - always_use_jwt_access (Optional[bool]): Whether self signed JWT should - be used for service account credentials. - url_scheme: the protocol scheme for the API endpoint. Normally - "https", but for testing or local servers, - "http" can be specified. - """ - # Run the base constructor - # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. 
- # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the - # credentials object - maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host) - if maybe_url_match is None: - raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER - - url_match_items = maybe_url_match.groupdict() - - host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host - - super().__init__( - host=host, - credentials=credentials, - client_info=client_info, - always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience - ) - self._session = AuthorizedSession( - self._credentials, default_host=self.DEFAULT_HOST) - if client_cert_source_for_mtls: - self._session.configure_mtls_channel(client_cert_source_for_mtls) - self._interceptor = interceptor or TextServiceRestInterceptor() - self._prep_wrapped_messages(client_info) - - class _EmbedText(TextServiceRestStub): - def __hash__(self): - return hash("EmbedText") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: text_service.EmbedTextRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> text_service.EmbedTextResponse: - r"""Call the embed text method over HTTP. - - Args: - request (~.text_service.EmbedTextRequest): - The request object. Request to get a text embedding from - the model. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.text_service.EmbedTextResponse: - The response to an EmbedTextRequest. - """ - - http_options: List[Dict[str, str]] = [{ - 'method': 'post', - 'uri': '/v1beta2/{model=models/*}:embedText', - 'body': '*', - }, - ] - request, metadata = self._interceptor.pre_embed_text(request, metadata) - pb_request = text_service.EmbedTextRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) - - # Jsonify the request body - - body = json_format.MessageToJson( - transcoded_request['body'], - including_default_value_fields=False, - use_integers_for_enums=True - ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] - - # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) - query_params.update(self._get_unset_required_fields(query_params)) - - query_params["$alt"] = "json;enum-encoding=int" - - # Send the request - headers = dict(metadata) - headers['Content-Type'] = 'application/json' - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, - ) - - # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception - # subclass.
- if response.status_code >= 400: - raise core_exceptions.from_http_response(response) - - # Return the response - resp = text_service.EmbedTextResponse() - pb_resp = text_service.EmbedTextResponse.pb(resp) - - json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) - resp = self._interceptor.post_embed_text(resp) - return resp - - class _GenerateText(TextServiceRestStub): - def __hash__(self): - return hash("GenerateText") - - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } - - @classmethod - def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: text_service.GenerateTextRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> text_service.GenerateTextResponse: - r"""Call the generate text method over HTTP. - - Args: - request (~.text_service.GenerateTextRequest): - The request object. Request to generate a text completion - response from the model. - retry (google.api_core.retry.Retry): Designation of what errors, if any, - should be retried. - timeout (float): The timeout for this request. - metadata (Sequence[Tuple[str, str]]): Strings which should be - sent along with the request as metadata. - - Returns: - ~.text_service.GenerateTextResponse: - The response from the model, - including candidate completions. - - """ - - http_options: List[Dict[str, str]] = [{ - 'method': 'post', - 'uri': '/v1beta2/{model=models/*}:generateText', - 'body': '*', - }, - ] - request, metadata = self._interceptor.pre_generate_text(request, metadata) - pb_request = text_service.GenerateTextRequest.pb(request) - transcoded_request = path_template.transcode(http_options, pb_request) - - # Jsonify the request body - - body = json_format.MessageToJson( - transcoded_request['body'], - including_default_value_fields=False, - use_integers_for_enums=True - ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] - - # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) - query_params.update(self._get_unset_required_fields(query_params)) - - query_params["$alt"] = "json;enum-encoding=int" - - # Send the request - headers = dict(metadata) - headers['Content-Type'] = 'application/json' - response = getattr(self._session, method)( - "{host}{uri}".format(host=self._host, uri=uri), - timeout=timeout, - headers=headers, - params=rest_helpers.flatten_query_params(query_params, strict=True), - data=body, - ) - - # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception - # subclass. - if response.status_code >= 400: - raise core_exceptions.from_http_response(response) - - # Return the response - resp = text_service.GenerateTextResponse() - pb_resp = text_service.GenerateTextResponse.pb(resp) - - json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) - resp = self._interceptor.post_generate_text(resp) - return resp - - @property - def embed_text(self) -> Callable[ - [text_service.EmbedTextRequest], - text_service.EmbedTextResponse]: - # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
- # In C++ this would require a dynamic_cast - return self._EmbedText(self._session, self._host, self._interceptor) # type: ignore - - @property - def generate_text(self) -> Callable[ - [text_service.GenerateTextRequest], - text_service.GenerateTextResponse]: - # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. - # In C++ this would require a dynamic_cast - return self._GenerateText(self._session, self._host, self._interceptor) # type: ignore - - @property - def kind(self) -> str: - return "rest" - - def close(self): - self._session.close() - - -__all__=( - 'TextServiceRestTransport', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/__init__.py deleted file mode 100644 index 6f8563368f76..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/__init__.py +++ /dev/null @@ -1,80 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from .citation import ( - CitationMetadata, - CitationSource, -) -from .discuss_service import ( - CountMessageTokensRequest, - CountMessageTokensResponse, - Example, - GenerateMessageRequest, - GenerateMessageResponse, - Message, - MessagePrompt, -) -from .model import ( - Model, -) -from .model_service import ( - GetModelRequest, - ListModelsRequest, - ListModelsResponse, -) -from .safety import ( - ContentFilter, - SafetyFeedback, - SafetyRating, - SafetySetting, - HarmCategory, -) -from .text_service import ( - Embedding, - EmbedTextRequest, - EmbedTextResponse, - GenerateTextRequest, - GenerateTextResponse, - TextCompletion, - TextPrompt, -) - -__all__ = ( - 'CitationMetadata', - 'CitationSource', - 'CountMessageTokensRequest', - 'CountMessageTokensResponse', - 'Example', - 'GenerateMessageRequest', - 'GenerateMessageResponse', - 'Message', - 'MessagePrompt', - 'Model', - 'GetModelRequest', - 'ListModelsRequest', - 'ListModelsResponse', - 'ContentFilter', - 'SafetyFeedback', - 'SafetyRating', - 'SafetySetting', - 'HarmCategory', - 'Embedding', - 'EmbedTextRequest', - 'EmbedTextResponse', - 'GenerateTextRequest', - 'GenerateTextResponse', - 'TextCompletion', - 'TextPrompt', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/citation.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/citation.py deleted file mode 100644 index e4ecf054b568..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/citation.py +++ /dev/null @@ -1,102 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
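The ``TextServiceRestInterceptor`` hooks shown above are wired in through the transport. A minimal sketch, assuming the interceptor and transport classes are exported from the generated ``transports`` package as usual:

    from google.ai import generativelanguage_v1beta2 as glm
    from google.ai.generativelanguage_v1beta2.services.text_service.transports import (
        TextServiceRestInterceptor,
        TextServiceRestTransport,
    )


    class LoggingInterceptor(TextServiceRestInterceptor):
        def pre_generate_text(self, request, metadata):
            # Inspect or annotate the request before it goes on the wire.
            print(f"GenerateText -> {request.model}")
            return request, metadata


    # The REST transport sends JSON over HTTP/1.1 instead of gRPC.
    transport = TextServiceRestTransport(interceptor=LoggingInterceptor())
    client = glm.TextServiceClient(transport=transport)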
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from __future__ import annotations - -from typing import MutableMapping, MutableSequence - -import proto # type: ignore - - -__protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta2', - manifest={ - 'CitationMetadata', - 'CitationSource', - }, -) - - -class CitationMetadata(proto.Message): - r"""A collection of source attributions for a piece of content. - - Attributes: - citation_sources (MutableSequence[google.ai.generativelanguage_v1beta2.types.CitationSource]): - Citations to sources for a specific response. - """ - - citation_sources: MutableSequence['CitationSource'] = proto.RepeatedField( - proto.MESSAGE, - number=1, - message='CitationSource', - ) - - -class CitationSource(proto.Message): - r"""A citation to a source for a portion of a specific response. - - .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields - - Attributes: - start_index (int): - Optional. Start of segment of the response - that is attributed to this source. - - Index indicates the start of the segment, - measured in bytes. - - This field is a member of `oneof`_ ``_start_index``. - end_index (int): - Optional. End of the attributed segment, - exclusive. - - This field is a member of `oneof`_ ``_end_index``. - uri (str): - Optional. URI that is attributed as a source - for a portion of the text. - - This field is a member of `oneof`_ ``_uri``. - license_ (str): - Optional. License for the GitHub project that - is attributed as a source for this segment. - - License info is required for code citations. - - This field is a member of `oneof`_ ``_license``. - """ - - start_index: int = proto.Field( - proto.INT32, - number=1, - optional=True, - ) - end_index: int = proto.Field( - proto.INT32, - number=2, - optional=True, - ) - uri: str = proto.Field( - proto.STRING, - number=3, - optional=True, - ) - license_: str = proto.Field( - proto.STRING, - number=4, - optional=True, - ) - - -__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/discuss_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/discuss_service.py deleted file mode 100644 index f91ed5b98bed..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/discuss_service.py +++ /dev/null @@ -1,358 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
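Because ``start_index`` and ``end_index`` above are byte offsets (with an exclusive end), a cited span should be sliced out of the UTF-8 encoding of the text rather than the Python string. An illustrative sketch, where ``cited_span`` is a hypothetical helper:

    def cited_span(text: str, source) -> str:
        # start_index/end_index count bytes into the UTF-8 text;
        # end_index is exclusive.
        data = text.encode("utf-8")
        return data[source.start_index:source.end_index].decode("utf-8")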
-# -from __future__ import annotations - -from typing import MutableMapping, MutableSequence - -import proto # type: ignore - -from google.ai.generativelanguage_v1beta2.types import citation -from google.ai.generativelanguage_v1beta2.types import safety - - -__protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta2', - manifest={ - 'GenerateMessageRequest', - 'GenerateMessageResponse', - 'Message', - 'MessagePrompt', - 'Example', - 'CountMessageTokensRequest', - 'CountMessageTokensResponse', - }, -) - - -class GenerateMessageRequest(proto.Message): - r"""Request to generate a message response from the model. - - .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields - - Attributes: - model (str): - Required. The name of the model to use. - - Format: ``name=models/{model}``. - prompt (google.ai.generativelanguage_v1beta2.types.MessagePrompt): - Required. The structured textual input given - to the model as a prompt. - Given a - prompt, the model will return what it predicts - is the next message in the discussion. - temperature (float): - Optional. Controls the randomness of the output. - - Values can range over ``[0.0,1.0]``, inclusive. A value - closer to ``1.0`` will produce responses that are more - varied, while a value closer to ``0.0`` will typically - result in less surprising responses from the model. - - This field is a member of `oneof`_ ``_temperature``. - candidate_count (int): - Optional. The number of generated response messages to - return. - - This value must be between ``[1, 8]``, inclusive. If unset, - this will default to ``1``. - - This field is a member of `oneof`_ ``_candidate_count``. - top_p (float): - Optional. The maximum cumulative probability of tokens to - consider when sampling. - - The model uses combined Top-k and nucleus sampling. - - Nucleus sampling considers the smallest set of tokens whose - probability sum is at least ``top_p``. - - This field is a member of `oneof`_ ``_top_p``. - top_k (int): - Optional. The maximum number of tokens to consider when - sampling. - - The model uses combined Top-k and nucleus sampling. - - Top-k sampling considers the set of ``top_k`` most probable - tokens. - - This field is a member of `oneof`_ ``_top_k``. - """ - - model: str = proto.Field( - proto.STRING, - number=1, - ) - prompt: 'MessagePrompt' = proto.Field( - proto.MESSAGE, - number=2, - message='MessagePrompt', - ) - temperature: float = proto.Field( - proto.FLOAT, - number=3, - optional=True, - ) - candidate_count: int = proto.Field( - proto.INT32, - number=4, - optional=True, - ) - top_p: float = proto.Field( - proto.FLOAT, - number=5, - optional=True, - ) - top_k: int = proto.Field( - proto.INT32, - number=6, - optional=True, - ) - - -class GenerateMessageResponse(proto.Message): - r"""The response from the model. - - This includes candidate messages and - conversation history in the form of chronologically-ordered - messages. - - Attributes: - candidates (MutableSequence[google.ai.generativelanguage_v1beta2.types.Message]): - Candidate response messages from the model. - messages (MutableSequence[google.ai.generativelanguage_v1beta2.types.Message]): - The conversation history used by the model. - filters (MutableSequence[google.ai.generativelanguage_v1beta2.types.ContentFilter]): - A set of content filtering metadata for the prompt and - response text. 
- - This indicates which ``SafetyCategory``\ (s) blocked a - candidate from this response, the lowest ``HarmProbability`` - that triggered a block, and the HarmThreshold setting for - that category. - """ - - candidates: MutableSequence['Message'] = proto.RepeatedField( - proto.MESSAGE, - number=1, - message='Message', - ) - messages: MutableSequence['Message'] = proto.RepeatedField( - proto.MESSAGE, - number=2, - message='Message', - ) - filters: MutableSequence[safety.ContentFilter] = proto.RepeatedField( - proto.MESSAGE, - number=3, - message=safety.ContentFilter, - ) - - -class Message(proto.Message): - r"""The base unit of structured text. - - A ``Message`` includes an ``author`` and the ``content`` of the - ``Message``. - - The ``author`` is used to tag messages when they are fed to the - model as text. - - - .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields - - Attributes: - author (str): - Optional. The author of this Message. - - This serves as a key for tagging - the content of this Message when it is fed to - the model as text. - - The author can be any alphanumeric string. - content (str): - Required. The text content of the structured ``Message``. - citation_metadata (google.ai.generativelanguage_v1beta2.types.CitationMetadata): - Output only. Citation information for model-generated - ``content`` in this ``Message``. - - If this ``Message`` was generated as output from the model, - this field may be populated with attribution information for - any text included in the ``content``. This field is used - only on output. - - This field is a member of `oneof`_ ``_citation_metadata``. - """ - - author: str = proto.Field( - proto.STRING, - number=1, - ) - content: str = proto.Field( - proto.STRING, - number=2, - ) - citation_metadata: citation.CitationMetadata = proto.Field( - proto.MESSAGE, - number=3, - optional=True, - message=citation.CitationMetadata, - ) - - -class MessagePrompt(proto.Message): - r"""All of the structured input text passed to the model as a prompt. - - A ``MessagePrompt`` contains a structured set of fields that provide - context for the conversation, examples of user input/model output - message pairs that prime the model to respond in different ways, and - the conversation history or list of messages representing the - alternating turns of the conversation between the user and the - model. - - Attributes: - context (str): - Optional. Text that should be provided to the model first to - ground the response. - - If not empty, this ``context`` will be given to the model - first before the ``examples`` and ``messages``. When using a - ``context`` be sure to provide it with every request to - maintain continuity. - - This field can be a description of your prompt to the model - to help provide context and guide the responses. Examples: - "Translate the phrase from English to French." or "Given a - statement, classify the sentiment as happy, sad or neutral." - - Anything included in this field will take precedence over - message history if the total input size exceeds the model's - ``input_token_limit`` and the input request is truncated. - examples (MutableSequence[google.ai.generativelanguage_v1beta2.types.Example]): - Optional. Examples of what the model should generate. - - This includes both user input and the response that the - model should emulate. 
- - These ``examples`` are treated identically to conversation - messages except that they take precedence over the history - in ``messages``: If the total input size exceeds the model's - ``input_token_limit`` the input will be truncated. Items - will be dropped from ``messages`` before ``examples``. - messages (MutableSequence[google.ai.generativelanguage_v1beta2.types.Message]): - Required. A snapshot of the recent conversation history - sorted chronologically. - - Turns alternate between two authors. - - If the total input size exceeds the model's - ``input_token_limit`` the input will be truncated: The - oldest items will be dropped from ``messages``. - """ - - context: str = proto.Field( - proto.STRING, - number=1, - ) - examples: MutableSequence['Example'] = proto.RepeatedField( - proto.MESSAGE, - number=2, - message='Example', - ) - messages: MutableSequence['Message'] = proto.RepeatedField( - proto.MESSAGE, - number=3, - message='Message', - ) - - -class Example(proto.Message): - r"""An input/output example used to instruct the Model. - - It demonstrates how the model should respond or format its - response. - - Attributes: - input (google.ai.generativelanguage_v1beta2.types.Message): - Required. An example of an input ``Message`` from the user. - output (google.ai.generativelanguage_v1beta2.types.Message): - Required. An example of what the model should - output given the input. - """ - - input: 'Message' = proto.Field( - proto.MESSAGE, - number=1, - message='Message', - ) - output: 'Message' = proto.Field( - proto.MESSAGE, - number=2, - message='Message', - ) - - -class CountMessageTokensRequest(proto.Message): - r"""Counts the number of tokens in the ``prompt`` sent to a model. - - Models may tokenize text differently, so each model may return a - different ``token_count``. - - Attributes: - model (str): - Required. The model's resource name. This serves as an ID - for the Model to use. - - This name should match a model name returned by the - ``ListModels`` method. - - Format: ``models/{model}`` - prompt (google.ai.generativelanguage_v1beta2.types.MessagePrompt): - Required. The prompt, whose token count is to - be returned. - """ - - model: str = proto.Field( - proto.STRING, - number=1, - ) - prompt: 'MessagePrompt' = proto.Field( - proto.MESSAGE, - number=2, - message='MessagePrompt', - ) - - -class CountMessageTokensResponse(proto.Message): - r"""A response from ``CountMessageTokens``. - - It returns the model's ``token_count`` for the ``prompt``. - - Attributes: - token_count (int): - The number of tokens that the ``model`` tokenizes the - ``prompt`` into. - - Always non-negative. - """ - - token_count: int = proto.Field( - proto.INT32, - number=1, - ) - - -__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model.py deleted file mode 100644 index d1698c736311..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model.py +++ /dev/null @@ -1,156 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
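Putting the discuss-service types above together: a prompt can carry grounding ``context``, priming ``examples``, and the running ``messages``, and the same object works for both ``GenerateMessage`` and ``CountMessageTokens``. A minimal sketch with a placeholder model name, assuming the flattened ``model``/``prompt`` arguments that the generated client exposes:

    from google.ai import generativelanguage_v1beta2 as glm

    client = glm.DiscussServiceClient()
    prompt = glm.MessagePrompt(
        context="Answer as briefly as possible.",
        examples=[
            glm.Example(
                input=glm.Message(content="What color is the sky?"),
                output=glm.Message(content="Blue."),
            )
        ],
        messages=[glm.Message(author="user", content="How are you?")],
    )
    # Models tokenize differently, so the count is model-specific.
    response = client.count_message_tokens(
        model="models/chat-bison-001",  # placeholder model name
        prompt=prompt,
    )
    print(response.token_count)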
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from __future__ import annotations - -from typing import MutableMapping, MutableSequence - -import proto # type: ignore - - -__protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta2', - manifest={ - 'Model', - }, -) - - -class Model(proto.Message): - r"""Information about a Generative Language Model. - - .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields - - Attributes: - name (str): - Required. The resource name of the ``Model``. - - Format: ``models/{model}`` with a ``{model}`` naming - convention of: - - - "{base_model_id}-{version}" - - Examples: - - - ``models/chat-bison-001`` - base_model_id (str): - Required. The name of the base model; pass this to the - generation request. - - Examples: - - - ``chat-bison`` - version (str): - Required. The version number of the model. - - This represents the major version. - display_name (str): - The human-readable name of the model. E.g. - "Chat Bison". - The name can be up to 128 characters long and - can consist of any UTF-8 characters. - description (str): - A short description of the model. - input_token_limit (int): - Maximum number of input tokens allowed for - this model. - output_token_limit (int): - Maximum number of output tokens available for - this model. - supported_generation_methods (MutableSequence[str]): - The model's supported generation methods. - - The method names are defined as Pascal case strings, such as - ``generateMessage`` which correspond to API methods. - temperature (float): - Controls the randomness of the output. - - Values can range over ``[0.0,1.0]``, inclusive. A value - closer to ``1.0`` will produce responses that are more - varied, while a value closer to ``0.0`` will typically - result in less surprising responses from the model. This - value specifies the default to be used by the backend while - making the call to the model. - - This field is a member of `oneof`_ ``_temperature``. - top_p (float): - For Nucleus sampling. - - Nucleus sampling considers the smallest set of tokens whose - probability sum is at least ``top_p``. This value specifies - the default to be used by the backend while making the call - to the model. - - This field is a member of `oneof`_ ``_top_p``. - top_k (int): - For Top-k sampling. - - Top-k sampling considers the set of ``top_k`` most probable - tokens. This value specifies the default to be used by the - backend while making the call to the model. - - This field is a member of `oneof`_ ``_top_k``.
- """ - - name: str = proto.Field( - proto.STRING, - number=1, - ) - base_model_id: str = proto.Field( - proto.STRING, - number=2, - ) - version: str = proto.Field( - proto.STRING, - number=3, - ) - display_name: str = proto.Field( - proto.STRING, - number=4, - ) - description: str = proto.Field( - proto.STRING, - number=5, - ) - input_token_limit: int = proto.Field( - proto.INT32, - number=6, - ) - output_token_limit: int = proto.Field( - proto.INT32, - number=7, - ) - supported_generation_methods: MutableSequence[str] = proto.RepeatedField( - proto.STRING, - number=8, - ) - temperature: float = proto.Field( - proto.FLOAT, - number=9, - optional=True, - ) - top_p: float = proto.Field( - proto.FLOAT, - number=10, - optional=True, - ) - top_k: int = proto.Field( - proto.INT32, - number=11, - optional=True, - ) - - -__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model_service.py deleted file mode 100644 index bb10f6ebd82a..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/model_service.py +++ /dev/null @@ -1,114 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from __future__ import annotations - -from typing import MutableMapping, MutableSequence - -import proto # type: ignore - -from google.ai.generativelanguage_v1beta2.types import model - - -__protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta2', - manifest={ - 'GetModelRequest', - 'ListModelsRequest', - 'ListModelsResponse', - }, -) - - -class GetModelRequest(proto.Message): - r"""Request for getting information about a specific Model. - - Attributes: - name (str): - Required. The resource name of the model. - - This name should match a model name returned by the - ``ListModels`` method. - - Format: ``models/{model}`` - """ - - name: str = proto.Field( - proto.STRING, - number=1, - ) - - -class ListModelsRequest(proto.Message): - r"""Request for listing all Models. - - Attributes: - page_size (int): - The maximum number of ``Models`` to return (per page). - - The service may return fewer models. If unspecified, at most - 50 models will be returned per page. This method returns at - most 1000 models per page, even if you pass a larger - page_size. - page_token (str): - A page token, received from a previous ``ListModels`` call. - - Provide the ``page_token`` returned by one request as an - argument to the next request to retrieve the next page. - - When paginating, all other parameters provided to - ``ListModels`` must match the call that provided the page - token. 
- """ - - page_size: int = proto.Field( - proto.INT32, - number=2, - ) - page_token: str = proto.Field( - proto.STRING, - number=3, - ) - - -class ListModelsResponse(proto.Message): - r"""Response from ``ListModel`` containing a paginated list of Models. - - Attributes: - models (MutableSequence[google.ai.generativelanguage_v1beta2.types.Model]): - The returned Models. - next_page_token (str): - A token, which can be sent as ``page_token`` to retrieve the - next page. - - If this field is omitted, there are no more pages. - """ - - @property - def raw_page(self): - return self - - models: MutableSequence[model.Model] = proto.RepeatedField( - proto.MESSAGE, - number=1, - message=model.Model, - ) - next_page_token: str = proto.Field( - proto.STRING, - number=2, - ) - - -__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/safety.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/safety.py deleted file mode 100644 index 990acf3f4dd2..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/safety.py +++ /dev/null @@ -1,247 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from __future__ import annotations - -from typing import MutableMapping, MutableSequence - -import proto # type: ignore - - -__protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta2', - manifest={ - 'HarmCategory', - 'ContentFilter', - 'SafetyFeedback', - 'SafetyRating', - 'SafetySetting', - }, -) - - -class HarmCategory(proto.Enum): - r"""The category of a rating. - - These categories cover various kinds of harms that developers - may wish to adjust. - - Values: - HARM_CATEGORY_UNSPECIFIED (0): - Category is unspecified. - HARM_CATEGORY_DEROGATORY (1): - Negative or harmful comments targeting - identity and/or protected attribute. - HARM_CATEGORY_TOXICITY (2): - Content that is rude, disrepspectful, or - profane. - HARM_CATEGORY_VIOLENCE (3): - Describes scenarios depictng violence against - an individual or group, or general descriptions - of gore. - HARM_CATEGORY_SEXUAL (4): - Contains references to sexual acts or other - lewd content. - HARM_CATEGORY_MEDICAL (5): - Promotes unchecked medical advice. - HARM_CATEGORY_DANGEROUS (6): - Dangerous content that promotes, facilitates, - or encourages harmful acts. - """ - HARM_CATEGORY_UNSPECIFIED = 0 - HARM_CATEGORY_DEROGATORY = 1 - HARM_CATEGORY_TOXICITY = 2 - HARM_CATEGORY_VIOLENCE = 3 - HARM_CATEGORY_SEXUAL = 4 - HARM_CATEGORY_MEDICAL = 5 - HARM_CATEGORY_DANGEROUS = 6 - - -class ContentFilter(proto.Message): - r"""Content filtering metadata associated with processing a - single request. - ContentFilter contains a reason and an optional supporting - string. The reason may be unspecified. - - - .. 
_oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields - - Attributes: - reason (google.ai.generativelanguage_v1beta2.types.ContentFilter.BlockedReason): - The reason content was blocked during request - processing. - message (str): - A string that describes the filtering - behavior in more detail. - - This field is a member of `oneof`_ ``_message``. - """ - class BlockedReason(proto.Enum): - r"""A list of reasons why content may have been blocked. - - Values: - BLOCKED_REASON_UNSPECIFIED (0): - A blocked reason was not specified. - SAFETY (1): - Content was blocked by safety settings. - OTHER (2): - Content was blocked, but the reason is - uncategorized. - """ - BLOCKED_REASON_UNSPECIFIED = 0 - SAFETY = 1 - OTHER = 2 - - reason: BlockedReason = proto.Field( - proto.ENUM, - number=1, - enum=BlockedReason, - ) - message: str = proto.Field( - proto.STRING, - number=2, - optional=True, - ) - - -class SafetyFeedback(proto.Message): - r"""Safety feedback for an entire request. - - This field is populated if content in the input and/or response - is blocked due to safety settings. SafetyFeedback may not exist - for every HarmCategory. Each SafetyFeedback will return the - safety settings used by the request as well as the lowest - HarmProbability that should be allowed in order to return a - result. - - Attributes: - rating (google.ai.generativelanguage_v1beta2.types.SafetyRating): - Safety rating evaluated from content. - setting (google.ai.generativelanguage_v1beta2.types.SafetySetting): - Safety settings applied to the request. - """ - - rating: 'SafetyRating' = proto.Field( - proto.MESSAGE, - number=1, - message='SafetyRating', - ) - setting: 'SafetySetting' = proto.Field( - proto.MESSAGE, - number=2, - message='SafetySetting', - ) - - -class SafetyRating(proto.Message): - r"""Safety rating for a piece of content. - - The safety rating contains the category of harm and the harm - probability level in that category for a piece of content. - Content is classified for safety across a number of harm - categories and the probability of the harm classification is - included here. - - Attributes: - category (google.ai.generativelanguage_v1beta2.types.HarmCategory): - Required. The category for this rating. - probability (google.ai.generativelanguage_v1beta2.types.SafetyRating.HarmProbability): - Required. The probability of harm for this - content. - """ - class HarmProbability(proto.Enum): - r"""The probability that a piece of content is harmful. - - The classification system gives the probability of the content - being unsafe. This does not indicate the severity of harm for a - piece of content. - - Values: - HARM_PROBABILITY_UNSPECIFIED (0): - Probability is unspecified. - NEGLIGIBLE (1): - Content has a negligible chance of being - unsafe. - LOW (2): - Content has a low chance of being unsafe. - MEDIUM (3): - Content has a medium chance of being unsafe. - HIGH (4): - Content has a high chance of being unsafe. - """ - HARM_PROBABILITY_UNSPECIFIED = 0 - NEGLIGIBLE = 1 - LOW = 2 - MEDIUM = 3 - HIGH = 4 - - category: 'HarmCategory' = proto.Field( - proto.ENUM, - number=3, - enum='HarmCategory', - ) - probability: HarmProbability = proto.Field( - proto.ENUM, - number=4, - enum=HarmProbability, - ) - - -class SafetySetting(proto.Message): - r"""Safety setting, affecting the safety-blocking behavior. - - Passing a safety setting for a category changes the allowed - probability that content is blocked.
- - Attributes: - category (google.ai.generativelanguage_v1beta2.types.HarmCategory): - Required. The category for this setting. - threshold (google.ai.generativelanguage_v1beta2.types.SafetySetting.HarmBlockThreshold): - Required. Controls the probability threshold - at which harm is blocked. - """ - class HarmBlockThreshold(proto.Enum): - r"""Block at and beyond a specified harm probability. - - Values: - HARM_BLOCK_THRESHOLD_UNSPECIFIED (0): - Threshold is unspecified. - BLOCK_LOW_AND_ABOVE (1): - Content with NEGLIGIBLE will be allowed. - BLOCK_MEDIUM_AND_ABOVE (2): - Content with NEGLIGIBLE and LOW will be - allowed. - BLOCK_ONLY_HIGH (3): - Content with NEGLIGIBLE, LOW, and MEDIUM will - be allowed. - """ - HARM_BLOCK_THRESHOLD_UNSPECIFIED = 0 - BLOCK_LOW_AND_ABOVE = 1 - BLOCK_MEDIUM_AND_ABOVE = 2 - BLOCK_ONLY_HIGH = 3 - - category: 'HarmCategory' = proto.Field( - proto.ENUM, - number=3, - enum='HarmCategory', - ) - threshold: HarmBlockThreshold = proto.Field( - proto.ENUM, - number=4, - enum=HarmBlockThreshold, - ) - - -__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/text_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/text_service.py deleted file mode 100644 index 572f3c5392b2..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/types/text_service.py +++ /dev/null @@ -1,333 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from __future__ import annotations - -from typing import MutableMapping, MutableSequence - -import proto # type: ignore - -from google.ai.generativelanguage_v1beta2.types import citation -from google.ai.generativelanguage_v1beta2.types import safety - - -__protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta2', - manifest={ - 'GenerateTextRequest', - 'GenerateTextResponse', - 'TextPrompt', - 'TextCompletion', - 'EmbedTextRequest', - 'EmbedTextResponse', - 'Embedding', - }, -) - - -class GenerateTextRequest(proto.Message): - r"""Request to generate a text completion response from the - model. - - - .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields - - Attributes: - model (str): - Required. The model name to use with the - format name=models/{model}. - prompt (google.ai.generativelanguage_v1beta2.types.TextPrompt): - Required. The free-form input text given to - the model as a prompt. - Given a prompt, the model will generate a - TextCompletion response it predicts as the - completion of the input text. - temperature (float): - Controls the randomness of the output. Note: The default - value varies by model, see the ``Model.temperature`` - attribute of the ``Model`` returned by the ``getModel`` - function. - - Values can range from [0.0,1.0], inclusive.
A value closer - to 1.0 will produce responses that are more varied and - creative, while a value closer to 0.0 will typically result - in more straightforward responses from the model. - - This field is a member of `oneof`_ ``_temperature``. - candidate_count (int): - Number of generated responses to return. - - This value must be between [1, 8], inclusive. If unset, this - will default to 1. - - This field is a member of `oneof`_ ``_candidate_count``. - max_output_tokens (int): - The maximum number of tokens to include in a - candidate. - If unset, this will default to 64. - - This field is a member of `oneof`_ ``_max_output_tokens``. - top_p (float): - The maximum cumulative probability of tokens to consider - when sampling. - - The model uses combined Top-k and nucleus sampling. - - Tokens are sorted based on their assigned probabilities so - that only the most likely tokens are considered. Top-k - sampling directly limits the maximum number of tokens to - consider, while Nucleus sampling limits the number of tokens - based on the cumulative probability. - - Note: The default value varies by model, see the - ``Model.top_p`` attribute of the ``Model`` returned by the - ``getModel`` function. - - This field is a member of `oneof`_ ``_top_p``. - top_k (int): - The maximum number of tokens to consider when sampling. - - The model uses combined Top-k and nucleus sampling. - - Top-k sampling considers the set of ``top_k`` most probable - tokens. Defaults to 40. - - Note: The default value varies by model, see the - ``Model.top_k`` attribute of the ``Model`` returned by the - ``getModel`` function. - - This field is a member of `oneof`_ ``_top_k``. - safety_settings (MutableSequence[google.ai.generativelanguage_v1beta2.types.SafetySetting]): - A list of unique ``SafetySetting`` instances for blocking - unsafe content that will be enforced on the - ``GenerateTextRequest.prompt`` and - ``GenerateTextResponse.candidates``. There should not be - more than one setting for each ``SafetyCategory`` type. The - API will block any prompts and responses that fail to meet - the thresholds set by these settings. This list overrides - the default settings for each ``SafetyCategory`` specified - in the safety_settings. If there is no ``SafetySetting`` for - a given ``SafetyCategory`` provided in the list, the API - will use the default safety setting for that category. - stop_sequences (MutableSequence[str]): - The set of character sequences (up to 5) that - will stop output generation. If specified, the - API will stop at the first appearance of a stop - sequence. The stop sequence will not be included - as part of the response.
- """ - - model: str = proto.Field( - proto.STRING, - number=1, - ) - prompt: 'TextPrompt' = proto.Field( - proto.MESSAGE, - number=2, - message='TextPrompt', - ) - temperature: float = proto.Field( - proto.FLOAT, - number=3, - optional=True, - ) - candidate_count: int = proto.Field( - proto.INT32, - number=4, - optional=True, - ) - max_output_tokens: int = proto.Field( - proto.INT32, - number=5, - optional=True, - ) - top_p: float = proto.Field( - proto.FLOAT, - number=6, - optional=True, - ) - top_k: int = proto.Field( - proto.INT32, - number=7, - optional=True, - ) - safety_settings: MutableSequence[safety.SafetySetting] = proto.RepeatedField( - proto.MESSAGE, - number=8, - message=safety.SafetySetting, - ) - stop_sequences: MutableSequence[str] = proto.RepeatedField( - proto.STRING, - number=9, - ) - - -class GenerateTextResponse(proto.Message): - r"""The response from the model, including candidate completions. - - Attributes: - candidates (MutableSequence[google.ai.generativelanguage_v1beta2.types.TextCompletion]): - Candidate responses from the model. - filters (MutableSequence[google.ai.generativelanguage_v1beta2.types.ContentFilter]): - A set of content filtering metadata for the prompt and - response text. - - This indicates which ``SafetyCategory``\ (s) blocked a - candidate from this response, the lowest ``HarmProbability`` - that triggered a block, and the HarmThreshold setting for - that category. This indicates the smallest change to the - ``SafetySettings`` that would be necessary to unblock at - least 1 response. - - The blocking is configured by the ``SafetySettings`` in the - request (or the default ``SafetySettings`` of the API). - safety_feedback (MutableSequence[google.ai.generativelanguage_v1beta2.types.SafetyFeedback]): - Returns any safety feedback related to - content filtering. - """ - - candidates: MutableSequence['TextCompletion'] = proto.RepeatedField( - proto.MESSAGE, - number=1, - message='TextCompletion', - ) - filters: MutableSequence[safety.ContentFilter] = proto.RepeatedField( - proto.MESSAGE, - number=3, - message=safety.ContentFilter, - ) - safety_feedback: MutableSequence[safety.SafetyFeedback] = proto.RepeatedField( - proto.MESSAGE, - number=4, - message=safety.SafetyFeedback, - ) - - -class TextPrompt(proto.Message): - r"""Text given to the model as a prompt. - - The Model will use this TextPrompt to Generate a text - completion. - - Attributes: - text (str): - Required. The prompt text. - """ - - text: str = proto.Field( - proto.STRING, - number=1, - ) - - -class TextCompletion(proto.Message): - r"""Output text returned from a model. - - .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields - - Attributes: - output (str): - Output only. The generated text returned from - the model. - safety_ratings (MutableSequence[google.ai.generativelanguage_v1beta2.types.SafetyRating]): - Ratings for the safety of a response. - - There is at most one rating per category. - citation_metadata (google.ai.generativelanguage_v1beta2.types.CitationMetadata): - Output only. Citation information for model-generated - ``output`` in this ``TextCompletion``. - - This field may be populated with attribution information for - any text included in the ``output``. - - This field is a member of `oneof`_ ``_citation_metadata``. 
- """ - - output: str = proto.Field( - proto.STRING, - number=1, - ) - safety_ratings: MutableSequence[safety.SafetyRating] = proto.RepeatedField( - proto.MESSAGE, - number=2, - message=safety.SafetyRating, - ) - citation_metadata: citation.CitationMetadata = proto.Field( - proto.MESSAGE, - number=3, - optional=True, - message=citation.CitationMetadata, - ) - - -class EmbedTextRequest(proto.Message): - r"""Request to get a text embedding from the model. - - Attributes: - model (str): - Required. The model name to use with the - format model=models/{model}. - text (str): - Required. The free-form input text that the - model will turn into an embedding. - """ - - model: str = proto.Field( - proto.STRING, - number=1, - ) - text: str = proto.Field( - proto.STRING, - number=2, - ) - - -class EmbedTextResponse(proto.Message): - r"""The response to a EmbedTextRequest. - - .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields - - Attributes: - embedding (google.ai.generativelanguage_v1beta2.types.Embedding): - Output only. The embedding generated from the - input text. - - This field is a member of `oneof`_ ``_embedding``. - """ - - embedding: 'Embedding' = proto.Field( - proto.MESSAGE, - number=1, - optional=True, - message='Embedding', - ) - - -class Embedding(proto.Message): - r"""A list of floats representing the embedding. - - Attributes: - value (MutableSequence[float]): - The embedding values. - """ - - value: MutableSequence[float] = proto.RepeatedField( - proto.FLOAT, - number=1, - ) - - -__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/mypy.ini b/owl-bot-staging/google-ai-generativelanguage/v1beta2/mypy.ini deleted file mode 100644 index 574c5aed394b..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/mypy.ini +++ /dev/null @@ -1,3 +0,0 @@ -[mypy] -python_version = 3.7 -namespace_packages = True diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/noxfile.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/noxfile.py deleted file mode 100644 index 96375ae41831..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/noxfile.py +++ /dev/null @@ -1,184 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import os -import pathlib -import shutil -import subprocess -import sys - - -import nox # type: ignore - -ALL_PYTHON = [ - "3.7", - "3.8", - "3.9", - "3.10", - "3.11", -] - -CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() - -LOWER_BOUND_CONSTRAINTS_FILE = CURRENT_DIRECTORY / "constraints.txt" -PACKAGE_NAME = subprocess.check_output([sys.executable, "setup.py", "--name"], encoding="utf-8") - -BLACK_VERSION = "black==22.3.0" -BLACK_PATHS = ["docs", "google", "tests", "samples", "noxfile.py", "setup.py"] -DEFAULT_PYTHON_VERSION = "3.11" - -nox.sessions = [ - "unit", - "cover", - "mypy", - "check_lower_bounds", - # exclude update_lower_bounds from default - "docs", - "blacken", - "lint", - "lint_setup_py", -] - -@nox.session(python=ALL_PYTHON) -def unit(session): - """Run the unit test suite.""" - - session.install('coverage', 'pytest', 'pytest-cov', 'pytest-asyncio', 'asyncmock; python_version < "3.8"') - session.install('-e', '.') - - session.run( - 'py.test', - '--quiet', - '--cov=google/ai/generativelanguage_v1beta2/', - '--cov=tests/', - '--cov-config=.coveragerc', - '--cov-report=term', - '--cov-report=html', - os.path.join('tests', 'unit', ''.join(session.posargs)) - ) - - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def cover(session): - """Run the final coverage report. - This outputs the coverage report aggregating coverage from the unit - test runs (not system test runs), and then erases coverage data. - """ - session.install("coverage", "pytest-cov") - session.run("coverage", "report", "--show-missing", "--fail-under=100") - - session.run("coverage", "erase") - - -@nox.session(python=ALL_PYTHON) -def mypy(session): - """Run the type checker.""" - session.install( - 'mypy', - 'types-requests', - 'types-protobuf' - ) - session.install('.') - session.run( - 'mypy', - '--explicit-package-bases', - 'google', - ) - - -@nox.session -def update_lower_bounds(session): - """Update lower bounds in constraints.txt to match setup.py""" - session.install('google-cloud-testutils') - session.install('.') - - session.run( - 'lower-bound-checker', - 'update', - '--package-name', - PACKAGE_NAME, - '--constraints-file', - str(LOWER_BOUND_CONSTRAINTS_FILE), - ) - - -@nox.session -def check_lower_bounds(session): - """Check lower bounds in setup.py are reflected in constraints file""" - session.install('google-cloud-testutils') - session.install('.') - - session.run( - 'lower-bound-checker', - 'check', - '--package-name', - PACKAGE_NAME, - '--constraints-file', - str(LOWER_BOUND_CONSTRAINTS_FILE), - ) - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def docs(session): - """Build the docs for this library.""" - - session.install("-e", ".") - session.install("sphinx==7.0.1", "alabaster", "recommonmark") - - shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) - session.run( - "sphinx-build", - "-W", # warnings as errors - "-T", # show full traceback on exception - "-N", # no colors - "-b", - "html", - "-d", - os.path.join("docs", "_build", "doctrees", ""), - os.path.join("docs", ""), - os.path.join("docs", "_build", "html", ""), - ) - - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def lint(session): - """Run linters. - - Returns a failure if the linters find linting errors or sufficiently - serious code quality issues. - """ - session.install("flake8", BLACK_VERSION) - session.run( - "black", - "--check", - *BLACK_PATHS, - ) - session.run("flake8", "google", "tests", "samples") - - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def blacken(session): - """Run black.
Format code to uniform standard.""" - session.install(BLACK_VERSION) - session.run( - "black", - *BLACK_PATHS, - ) - - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def lint_setup_py(session): - """Verify that setup.py is valid (including RST check).""" - session.install("docutils", "pygments") - session.run("python", "setup.py", "check", "--restructuredtext", "--strict") diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_async.py deleted file mode 100644 index 1b587e44368d..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_async.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Generated code. DO NOT EDIT! -# -# Snippet for CountMessageTokens -# NOTE: This snippet has been automatically generated for illustrative purposes only. -# It may require modifications to work in your environment. - -# To install the latest published package dependency, execute the following: -# python3 -m pip install google-ai-generativelanguage - - -# [START generativelanguage_v1beta2_generated_DiscussService_CountMessageTokens_async] -# This snippet has been automatically generated and should be regarded as a -# code template only. -# It will require modifications to work: -# - It may require correct/in-range values for request initialization. 
-# - It may require specifying regional endpoints when creating the service -# client as shown in: -# https://googleapis.dev/python/google-api-core/latest/client_options.html -from google.ai import generativelanguage_v1beta2 - - -async def sample_count_message_tokens(): - # Create a client - client = generativelanguage_v1beta2.DiscussServiceAsyncClient() - - # Initialize request argument(s) - prompt = generativelanguage_v1beta2.MessagePrompt() - prompt.messages.content = "content_value" - - request = generativelanguage_v1beta2.CountMessageTokensRequest( - model="model_value", - prompt=prompt, - ) - - # Make the request - response = await client.count_message_tokens(request=request) - - # Handle the response - print(response) - -# [END generativelanguage_v1beta2_generated_DiscussService_CountMessageTokens_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_sync.py deleted file mode 100644 index 590d967fdfa6..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_sync.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Generated code. DO NOT EDIT! -# -# Snippet for CountMessageTokens -# NOTE: This snippet has been automatically generated for illustrative purposes only. -# It may require modifications to work in your environment. - -# To install the latest published package dependency, execute the following: -# python3 -m pip install google-ai-generativelanguage - - -# [START generativelanguage_v1beta2_generated_DiscussService_CountMessageTokens_sync] -# This snippet has been automatically generated and should be regarded as a -# code template only. -# It will require modifications to work: -# - It may require correct/in-range values for request initialization. 
-# - It may require specifying regional endpoints when creating the service -# client as shown in: -# https://googleapis.dev/python/google-api-core/latest/client_options.html -from google.ai import generativelanguage_v1beta2 - - -def sample_count_message_tokens(): - # Create a client - client = generativelanguage_v1beta2.DiscussServiceClient() - - # Initialize request argument(s) - prompt = generativelanguage_v1beta2.MessagePrompt() - prompt.messages.content = "content_value" - - request = generativelanguage_v1beta2.CountMessageTokensRequest( - model="model_value", - prompt=prompt, - ) - - # Make the request - response = client.count_message_tokens(request=request) - - # Handle the response - print(response) - -# [END generativelanguage_v1beta2_generated_DiscussService_CountMessageTokens_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_async.py deleted file mode 100644 index 22848d706b77..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_async.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Generated code. DO NOT EDIT! -# -# Snippet for GenerateMessage -# NOTE: This snippet has been automatically generated for illustrative purposes only. -# It may require modifications to work in your environment. - -# To install the latest published package dependency, execute the following: -# python3 -m pip install google-ai-generativelanguage - - -# [START generativelanguage_v1beta2_generated_DiscussService_GenerateMessage_async] -# This snippet has been automatically generated and should be regarded as a -# code template only. -# It will require modifications to work: -# - It may require correct/in-range values for request initialization. 
-# - It may require specifying regional endpoints when creating the service -# client as shown in: -# https://googleapis.dev/python/google-api-core/latest/client_options.html -from google.ai import generativelanguage_v1beta2 - - -async def sample_generate_message(): - # Create a client - client = generativelanguage_v1beta2.DiscussServiceAsyncClient() - - # Initialize request argument(s) - prompt = generativelanguage_v1beta2.MessagePrompt() - prompt.messages.content = "content_value" - - request = generativelanguage_v1beta2.GenerateMessageRequest( - model="model_value", - prompt=prompt, - ) - - # Make the request - response = await client.generate_message(request=request) - - # Handle the response - print(response) - -# [END generativelanguage_v1beta2_generated_DiscussService_GenerateMessage_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_sync.py deleted file mode 100644 index 30106bdee93b..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_discuss_service_generate_message_sync.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Generated code. DO NOT EDIT! -# -# Snippet for GenerateMessage -# NOTE: This snippet has been automatically generated for illustrative purposes only. -# It may require modifications to work in your environment. - -# To install the latest published package dependency, execute the following: -# python3 -m pip install google-ai-generativelanguage - - -# [START generativelanguage_v1beta2_generated_DiscussService_GenerateMessage_sync] -# This snippet has been automatically generated and should be regarded as a -# code template only. -# It will require modifications to work: -# - It may require correct/in-range values for request initialization. 
-# - It may require specifying regional endpoints when creating the service -# client as shown in: -# https://googleapis.dev/python/google-api-core/latest/client_options.html -from google.ai import generativelanguage_v1beta2 - - -def sample_generate_message(): - # Create a client - client = generativelanguage_v1beta2.DiscussServiceClient() - - # Initialize request argument(s) - prompt = generativelanguage_v1beta2.MessagePrompt() - prompt.messages.content = "content_value" - - request = generativelanguage_v1beta2.GenerateMessageRequest( - model="model_value", - prompt=prompt, - ) - - # Make the request - response = client.generate_message(request=request) - - # Handle the response - print(response) - -# [END generativelanguage_v1beta2_generated_DiscussService_GenerateMessage_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_async.py deleted file mode 100644 index 1eb30ff00aaa..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_async.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Generated code. DO NOT EDIT! -# -# Snippet for GetModel -# NOTE: This snippet has been automatically generated for illustrative purposes only. -# It may require modifications to work in your environment. - -# To install the latest published package dependency, execute the following: -# python3 -m pip install google-ai-generativelanguage - - -# [START generativelanguage_v1beta2_generated_ModelService_GetModel_async] -# This snippet has been automatically generated and should be regarded as a -# code template only. -# It will require modifications to work: -# - It may require correct/in-range values for request initialization. 
-# - It may require specifying regional endpoints when creating the service -# client as shown in: -# https://googleapis.dev/python/google-api-core/latest/client_options.html -from google.ai import generativelanguage_v1beta2 - - -async def sample_get_model(): - # Create a client - client = generativelanguage_v1beta2.ModelServiceAsyncClient() - - # Initialize request argument(s) - request = generativelanguage_v1beta2.GetModelRequest( - name="name_value", - ) - - # Make the request - response = await client.get_model(request=request) - - # Handle the response - print(response) - -# [END generativelanguage_v1beta2_generated_ModelService_GetModel_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_sync.py deleted file mode 100644 index 84eda9615b78..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_get_model_sync.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Generated code. DO NOT EDIT! -# -# Snippet for GetModel -# NOTE: This snippet has been automatically generated for illustrative purposes only. -# It may require modifications to work in your environment. - -# To install the latest published package dependency, execute the following: -# python3 -m pip install google-ai-generativelanguage - - -# [START generativelanguage_v1beta2_generated_ModelService_GetModel_sync] -# This snippet has been automatically generated and should be regarded as a -# code template only. -# It will require modifications to work: -# - It may require correct/in-range values for request initialization. 
-# - It may require specifying regional endpoints when creating the service -# client as shown in: -# https://googleapis.dev/python/google-api-core/latest/client_options.html -from google.ai import generativelanguage_v1beta2 - - -def sample_get_model(): - # Create a client - client = generativelanguage_v1beta2.ModelServiceClient() - - # Initialize request argument(s) - request = generativelanguage_v1beta2.GetModelRequest( - name="name_value", - ) - - # Make the request - response = client.get_model(request=request) - - # Handle the response - print(response) - -# [END generativelanguage_v1beta2_generated_ModelService_GetModel_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_async.py deleted file mode 100644 index 7d21ae65d7e6..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_async.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Generated code. DO NOT EDIT! -# -# Snippet for ListModels -# NOTE: This snippet has been automatically generated for illustrative purposes only. -# It may require modifications to work in your environment. - -# To install the latest published package dependency, execute the following: -# python3 -m pip install google-ai-generativelanguage - - -# [START generativelanguage_v1beta2_generated_ModelService_ListModels_async] -# This snippet has been automatically generated and should be regarded as a -# code template only. -# It will require modifications to work: -# - It may require correct/in-range values for request initialization. 
-# - It may require specifying regional endpoints when creating the service -# client as shown in: -# https://googleapis.dev/python/google-api-core/latest/client_options.html -from google.ai import generativelanguage_v1beta2 - - -async def sample_list_models(): - # Create a client - client = generativelanguage_v1beta2.ModelServiceAsyncClient() - - # Initialize request argument(s) - request = generativelanguage_v1beta2.ListModelsRequest( - ) - - # Make the request - page_result = client.list_models(request=request) - - # Handle the response - async for response in page_result: - print(response) - -# [END generativelanguage_v1beta2_generated_ModelService_ListModels_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_sync.py deleted file mode 100644 index e94decf56a96..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_model_service_list_models_sync.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Generated code. DO NOT EDIT! -# -# Snippet for ListModels -# NOTE: This snippet has been automatically generated for illustrative purposes only. -# It may require modifications to work in your environment. - -# To install the latest published package dependency, execute the following: -# python3 -m pip install google-ai-generativelanguage - - -# [START generativelanguage_v1beta2_generated_ModelService_ListModels_sync] -# This snippet has been automatically generated and should be regarded as a -# code template only. -# It will require modifications to work: -# - It may require correct/in-range values for request initialization. 
-# - It may require specifying regional endpoints when creating the service -# client as shown in: -# https://googleapis.dev/python/google-api-core/latest/client_options.html -from google.ai import generativelanguage_v1beta2 - - -def sample_list_models(): - # Create a client - client = generativelanguage_v1beta2.ModelServiceClient() - - # Initialize request argument(s) - request = generativelanguage_v1beta2.ListModelsRequest( - ) - - # Make the request - page_result = client.list_models(request=request) - - # Handle the response - for response in page_result: - print(response) - -# [END generativelanguage_v1beta2_generated_ModelService_ListModels_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_async.py deleted file mode 100644 index d970ee8f589c..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_async.py +++ /dev/null @@ -1,53 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Generated code. DO NOT EDIT! -# -# Snippet for EmbedText -# NOTE: This snippet has been automatically generated for illustrative purposes only. -# It may require modifications to work in your environment. - -# To install the latest published package dependency, execute the following: -# python3 -m pip install google-ai-generativelanguage - - -# [START generativelanguage_v1beta2_generated_TextService_EmbedText_async] -# This snippet has been automatically generated and should be regarded as a -# code template only. -# It will require modifications to work: -# - It may require correct/in-range values for request initialization. 
-# - It may require specifying regional endpoints when creating the service -# client as shown in: -# https://googleapis.dev/python/google-api-core/latest/client_options.html -from google.ai import generativelanguage_v1beta2 - - -async def sample_embed_text(): - # Create a client - client = generativelanguage_v1beta2.TextServiceAsyncClient() - - # Initialize request argument(s) - request = generativelanguage_v1beta2.EmbedTextRequest( - model="model_value", - text="text_value", - ) - - # Make the request - response = await client.embed_text(request=request) - - # Handle the response - print(response) - -# [END generativelanguage_v1beta2_generated_TextService_EmbedText_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_sync.py deleted file mode 100644 index c00795a1f795..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_embed_text_sync.py +++ /dev/null @@ -1,53 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Generated code. DO NOT EDIT! -# -# Snippet for EmbedText -# NOTE: This snippet has been automatically generated for illustrative purposes only. -# It may require modifications to work in your environment. - -# To install the latest published package dependency, execute the following: -# python3 -m pip install google-ai-generativelanguage - - -# [START generativelanguage_v1beta2_generated_TextService_EmbedText_sync] -# This snippet has been automatically generated and should be regarded as a -# code template only. -# It will require modifications to work: -# - It may require correct/in-range values for request initialization. 
-# - It may require specifying regional endpoints when creating the service -# client as shown in: -# https://googleapis.dev/python/google-api-core/latest/client_options.html -from google.ai import generativelanguage_v1beta2 - - -def sample_embed_text(): - # Create a client - client = generativelanguage_v1beta2.TextServiceClient() - - # Initialize request argument(s) - request = generativelanguage_v1beta2.EmbedTextRequest( - model="model_value", - text="text_value", - ) - - # Make the request - response = client.embed_text(request=request) - - # Handle the response - print(response) - -# [END generativelanguage_v1beta2_generated_TextService_EmbedText_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_async.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_async.py deleted file mode 100644 index f41f480f205c..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_async.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Generated code. DO NOT EDIT! -# -# Snippet for GenerateText -# NOTE: This snippet has been automatically generated for illustrative purposes only. -# It may require modifications to work in your environment. - -# To install the latest published package dependency, execute the following: -# python3 -m pip install google-ai-generativelanguage - - -# [START generativelanguage_v1beta2_generated_TextService_GenerateText_async] -# This snippet has been automatically generated and should be regarded as a -# code template only. -# It will require modifications to work: -# - It may require correct/in-range values for request initialization. 
-# - It may require specifying regional endpoints when creating the service -# client as shown in: -# https://googleapis.dev/python/google-api-core/latest/client_options.html -from google.ai import generativelanguage_v1beta2 - - -async def sample_generate_text(): - # Create a client - client = generativelanguage_v1beta2.TextServiceAsyncClient() - - # Initialize request argument(s) - prompt = generativelanguage_v1beta2.TextPrompt() - prompt.text = "text_value" - - request = generativelanguage_v1beta2.GenerateTextRequest( - model="model_value", - prompt=prompt, - ) - - # Make the request - response = await client.generate_text(request=request) - - # Handle the response - print(response) - -# [END generativelanguage_v1beta2_generated_TextService_GenerateText_async] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_sync.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_sync.py deleted file mode 100644 index 900ed0003aeb..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/generativelanguage_v1beta2_generated_text_service_generate_text_sync.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# Generated code. DO NOT EDIT! -# -# Snippet for GenerateText -# NOTE: This snippet has been automatically generated for illustrative purposes only. -# It may require modifications to work in your environment. - -# To install the latest published package dependency, execute the following: -# python3 -m pip install google-ai-generativelanguage - - -# [START generativelanguage_v1beta2_generated_TextService_GenerateText_sync] -# This snippet has been automatically generated and should be regarded as a -# code template only. -# It will require modifications to work: -# - It may require correct/in-range values for request initialization. 
-# - It may require specifying regional endpoints when creating the service -# client as shown in: -# https://googleapis.dev/python/google-api-core/latest/client_options.html -from google.ai import generativelanguage_v1beta2 - - -def sample_generate_text(): - # Create a client - client = generativelanguage_v1beta2.TextServiceClient() - - # Initialize request argument(s) - prompt = generativelanguage_v1beta2.TextPrompt() - prompt.text = "text_value" - - request = generativelanguage_v1beta2.GenerateTextRequest( - model="model_value", - prompt=prompt, - ) - - # Make the request - response = client.generate_text(request=request) - - # Handle the response - print(response) - -# [END generativelanguage_v1beta2_generated_TextService_GenerateText_sync] diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta2.json b/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta2.json deleted file mode 100644 index 5b7d0a0509b4..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta2.json +++ /dev/null @@ -1,1093 +0,0 @@ -{ - "clientLibrary": { - "apis": [ - { - "id": "google.ai.generativelanguage.v1beta2", - "version": "v1beta2" - } - ], - "language": "PYTHON", - "name": "google-ai-generativelanguage", - "version": "0.1.0" - }, - "snippets": [ - { - "canonical": true, - "clientMethod": { - "async": true, - "client": { - "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceAsyncClient", - "shortName": "DiscussServiceAsyncClient" - }, - "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceAsyncClient.count_message_tokens", - "method": { - "fullName": "google.ai.generativelanguage.v1beta2.DiscussService.CountMessageTokens", - "service": { - "fullName": "google.ai.generativelanguage.v1beta2.DiscussService", - "shortName": "DiscussService" - }, - "shortName": "CountMessageTokens" - }, - "parameters": [ - { - "name": "request", - "type": "google.ai.generativelanguage_v1beta2.types.CountMessageTokensRequest" - }, - { - "name": "model", - "type": "str" - }, - { - "name": "prompt", - "type": "google.ai.generativelanguage_v1beta2.types.MessagePrompt" - }, - { - "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "google.ai.generativelanguage_v1beta2.types.CountMessageTokensResponse", - "shortName": "count_message_tokens" - }, - "description": "Sample for CountMessageTokens", - "file": "generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_async.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "generativelanguage_v1beta2_generated_DiscussService_CountMessageTokens_async", - "segments": [ - { - "end": 55, - "start": 27, - "type": "FULL" - }, - { - "end": 55, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 49, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 52, - "start": 50, - "type": "REQUEST_EXECUTION" - }, - { - "end": 56, - "start": 53, - "type": "RESPONSE_HANDLING" - } - ], - "title": "generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_async.py" - }, - { - "canonical": true, - "clientMethod": { - "client": { - "fullName": 
"google.ai.generativelanguage_v1beta2.DiscussServiceClient", - "shortName": "DiscussServiceClient" - }, - "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceClient.count_message_tokens", - "method": { - "fullName": "google.ai.generativelanguage.v1beta2.DiscussService.CountMessageTokens", - "service": { - "fullName": "google.ai.generativelanguage.v1beta2.DiscussService", - "shortName": "DiscussService" - }, - "shortName": "CountMessageTokens" - }, - "parameters": [ - { - "name": "request", - "type": "google.ai.generativelanguage_v1beta2.types.CountMessageTokensRequest" - }, - { - "name": "model", - "type": "str" - }, - { - "name": "prompt", - "type": "google.ai.generativelanguage_v1beta2.types.MessagePrompt" - }, - { - "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "google.ai.generativelanguage_v1beta2.types.CountMessageTokensResponse", - "shortName": "count_message_tokens" - }, - "description": "Sample for CountMessageTokens", - "file": "generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_sync.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "generativelanguage_v1beta2_generated_DiscussService_CountMessageTokens_sync", - "segments": [ - { - "end": 55, - "start": 27, - "type": "FULL" - }, - { - "end": 55, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 49, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 52, - "start": 50, - "type": "REQUEST_EXECUTION" - }, - { - "end": 56, - "start": 53, - "type": "RESPONSE_HANDLING" - } - ], - "title": "generativelanguage_v1beta2_generated_discuss_service_count_message_tokens_sync.py" - }, - { - "canonical": true, - "clientMethod": { - "async": true, - "client": { - "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceAsyncClient", - "shortName": "DiscussServiceAsyncClient" - }, - "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceAsyncClient.generate_message", - "method": { - "fullName": "google.ai.generativelanguage.v1beta2.DiscussService.GenerateMessage", - "service": { - "fullName": "google.ai.generativelanguage.v1beta2.DiscussService", - "shortName": "DiscussService" - }, - "shortName": "GenerateMessage" - }, - "parameters": [ - { - "name": "request", - "type": "google.ai.generativelanguage_v1beta2.types.GenerateMessageRequest" - }, - { - "name": "model", - "type": "str" - }, - { - "name": "prompt", - "type": "google.ai.generativelanguage_v1beta2.types.MessagePrompt" - }, - { - "name": "temperature", - "type": "float" - }, - { - "name": "candidate_count", - "type": "int" - }, - { - "name": "top_p", - "type": "float" - }, - { - "name": "top_k", - "type": "int" - }, - { - "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "google.ai.generativelanguage_v1beta2.types.GenerateMessageResponse", - "shortName": "generate_message" - }, - "description": "Sample for GenerateMessage", - "file": "generativelanguage_v1beta2_generated_discuss_service_generate_message_async.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "generativelanguage_v1beta2_generated_DiscussService_GenerateMessage_async", - "segments": [ - { - "end": 55, - "start": 27, - "type": "FULL" - }, - { - 
"end": 55, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 49, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 52, - "start": 50, - "type": "REQUEST_EXECUTION" - }, - { - "end": 56, - "start": 53, - "type": "RESPONSE_HANDLING" - } - ], - "title": "generativelanguage_v1beta2_generated_discuss_service_generate_message_async.py" - }, - { - "canonical": true, - "clientMethod": { - "client": { - "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceClient", - "shortName": "DiscussServiceClient" - }, - "fullName": "google.ai.generativelanguage_v1beta2.DiscussServiceClient.generate_message", - "method": { - "fullName": "google.ai.generativelanguage.v1beta2.DiscussService.GenerateMessage", - "service": { - "fullName": "google.ai.generativelanguage.v1beta2.DiscussService", - "shortName": "DiscussService" - }, - "shortName": "GenerateMessage" - }, - "parameters": [ - { - "name": "request", - "type": "google.ai.generativelanguage_v1beta2.types.GenerateMessageRequest" - }, - { - "name": "model", - "type": "str" - }, - { - "name": "prompt", - "type": "google.ai.generativelanguage_v1beta2.types.MessagePrompt" - }, - { - "name": "temperature", - "type": "float" - }, - { - "name": "candidate_count", - "type": "int" - }, - { - "name": "top_p", - "type": "float" - }, - { - "name": "top_k", - "type": "int" - }, - { - "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "google.ai.generativelanguage_v1beta2.types.GenerateMessageResponse", - "shortName": "generate_message" - }, - "description": "Sample for GenerateMessage", - "file": "generativelanguage_v1beta2_generated_discuss_service_generate_message_sync.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "generativelanguage_v1beta2_generated_DiscussService_GenerateMessage_sync", - "segments": [ - { - "end": 55, - "start": 27, - "type": "FULL" - }, - { - "end": 55, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 49, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 52, - "start": 50, - "type": "REQUEST_EXECUTION" - }, - { - "end": 56, - "start": 53, - "type": "RESPONSE_HANDLING" - } - ], - "title": "generativelanguage_v1beta2_generated_discuss_service_generate_message_sync.py" - }, - { - "canonical": true, - "clientMethod": { - "async": true, - "client": { - "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceAsyncClient", - "shortName": "ModelServiceAsyncClient" - }, - "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceAsyncClient.get_model", - "method": { - "fullName": "google.ai.generativelanguage.v1beta2.ModelService.GetModel", - "service": { - "fullName": "google.ai.generativelanguage.v1beta2.ModelService", - "shortName": "ModelService" - }, - "shortName": "GetModel" - }, - "parameters": [ - { - "name": "request", - "type": "google.ai.generativelanguage_v1beta2.types.GetModelRequest" - }, - { - "name": "name", - "type": "str" - }, - { - "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "google.ai.generativelanguage_v1beta2.types.Model", - "shortName": "get_model" - }, - "description": "Sample for GetModel", - "file": 
"generativelanguage_v1beta2_generated_model_service_get_model_async.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "generativelanguage_v1beta2_generated_ModelService_GetModel_async", - "segments": [ - { - "end": 51, - "start": 27, - "type": "FULL" - }, - { - "end": 51, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 45, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 48, - "start": 46, - "type": "REQUEST_EXECUTION" - }, - { - "end": 52, - "start": 49, - "type": "RESPONSE_HANDLING" - } - ], - "title": "generativelanguage_v1beta2_generated_model_service_get_model_async.py" - }, - { - "canonical": true, - "clientMethod": { - "client": { - "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceClient", - "shortName": "ModelServiceClient" - }, - "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceClient.get_model", - "method": { - "fullName": "google.ai.generativelanguage.v1beta2.ModelService.GetModel", - "service": { - "fullName": "google.ai.generativelanguage.v1beta2.ModelService", - "shortName": "ModelService" - }, - "shortName": "GetModel" - }, - "parameters": [ - { - "name": "request", - "type": "google.ai.generativelanguage_v1beta2.types.GetModelRequest" - }, - { - "name": "name", - "type": "str" - }, - { - "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "google.ai.generativelanguage_v1beta2.types.Model", - "shortName": "get_model" - }, - "description": "Sample for GetModel", - "file": "generativelanguage_v1beta2_generated_model_service_get_model_sync.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "generativelanguage_v1beta2_generated_ModelService_GetModel_sync", - "segments": [ - { - "end": 51, - "start": 27, - "type": "FULL" - }, - { - "end": 51, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 45, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 48, - "start": 46, - "type": "REQUEST_EXECUTION" - }, - { - "end": 52, - "start": 49, - "type": "RESPONSE_HANDLING" - } - ], - "title": "generativelanguage_v1beta2_generated_model_service_get_model_sync.py" - }, - { - "canonical": true, - "clientMethod": { - "async": true, - "client": { - "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceAsyncClient", - "shortName": "ModelServiceAsyncClient" - }, - "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceAsyncClient.list_models", - "method": { - "fullName": "google.ai.generativelanguage.v1beta2.ModelService.ListModels", - "service": { - "fullName": "google.ai.generativelanguage.v1beta2.ModelService", - "shortName": "ModelService" - }, - "shortName": "ListModels" - }, - "parameters": [ - { - "name": "request", - "type": "google.ai.generativelanguage_v1beta2.types.ListModelsRequest" - }, - { - "name": "page_size", - "type": "int" - }, - { - "name": "page_token", - "type": "str" - }, - { - "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "google.ai.generativelanguage_v1beta2.services.model_service.pagers.ListModelsAsyncPager", - "shortName": "list_models" - }, - "description": "Sample for ListModels", - "file": 
"generativelanguage_v1beta2_generated_model_service_list_models_async.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "generativelanguage_v1beta2_generated_ModelService_ListModels_async", - "segments": [ - { - "end": 51, - "start": 27, - "type": "FULL" - }, - { - "end": 51, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 44, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 47, - "start": 45, - "type": "REQUEST_EXECUTION" - }, - { - "end": 52, - "start": 48, - "type": "RESPONSE_HANDLING" - } - ], - "title": "generativelanguage_v1beta2_generated_model_service_list_models_async.py" - }, - { - "canonical": true, - "clientMethod": { - "client": { - "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceClient", - "shortName": "ModelServiceClient" - }, - "fullName": "google.ai.generativelanguage_v1beta2.ModelServiceClient.list_models", - "method": { - "fullName": "google.ai.generativelanguage.v1beta2.ModelService.ListModels", - "service": { - "fullName": "google.ai.generativelanguage.v1beta2.ModelService", - "shortName": "ModelService" - }, - "shortName": "ListModels" - }, - "parameters": [ - { - "name": "request", - "type": "google.ai.generativelanguage_v1beta2.types.ListModelsRequest" - }, - { - "name": "page_size", - "type": "int" - }, - { - "name": "page_token", - "type": "str" - }, - { - "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "google.ai.generativelanguage_v1beta2.services.model_service.pagers.ListModelsPager", - "shortName": "list_models" - }, - "description": "Sample for ListModels", - "file": "generativelanguage_v1beta2_generated_model_service_list_models_sync.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "generativelanguage_v1beta2_generated_ModelService_ListModels_sync", - "segments": [ - { - "end": 51, - "start": 27, - "type": "FULL" - }, - { - "end": 51, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 44, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 47, - "start": 45, - "type": "REQUEST_EXECUTION" - }, - { - "end": 52, - "start": 48, - "type": "RESPONSE_HANDLING" - } - ], - "title": "generativelanguage_v1beta2_generated_model_service_list_models_sync.py" - }, - { - "canonical": true, - "clientMethod": { - "async": true, - "client": { - "fullName": "google.ai.generativelanguage_v1beta2.TextServiceAsyncClient", - "shortName": "TextServiceAsyncClient" - }, - "fullName": "google.ai.generativelanguage_v1beta2.TextServiceAsyncClient.embed_text", - "method": { - "fullName": "google.ai.generativelanguage.v1beta2.TextService.EmbedText", - "service": { - "fullName": "google.ai.generativelanguage.v1beta2.TextService", - "shortName": "TextService" - }, - "shortName": "EmbedText" - }, - "parameters": [ - { - "name": "request", - "type": "google.ai.generativelanguage_v1beta2.types.EmbedTextRequest" - }, - { - "name": "model", - "type": "str" - }, - { - "name": "text", - "type": "str" - }, - { - "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "google.ai.generativelanguage_v1beta2.types.EmbedTextResponse", - "shortName": "embed_text" - }, - 
"description": "Sample for EmbedText", - "file": "generativelanguage_v1beta2_generated_text_service_embed_text_async.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "generativelanguage_v1beta2_generated_TextService_EmbedText_async", - "segments": [ - { - "end": 52, - "start": 27, - "type": "FULL" - }, - { - "end": 52, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 46, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 49, - "start": 47, - "type": "REQUEST_EXECUTION" - }, - { - "end": 53, - "start": 50, - "type": "RESPONSE_HANDLING" - } - ], - "title": "generativelanguage_v1beta2_generated_text_service_embed_text_async.py" - }, - { - "canonical": true, - "clientMethod": { - "client": { - "fullName": "google.ai.generativelanguage_v1beta2.TextServiceClient", - "shortName": "TextServiceClient" - }, - "fullName": "google.ai.generativelanguage_v1beta2.TextServiceClient.embed_text", - "method": { - "fullName": "google.ai.generativelanguage.v1beta2.TextService.EmbedText", - "service": { - "fullName": "google.ai.generativelanguage.v1beta2.TextService", - "shortName": "TextService" - }, - "shortName": "EmbedText" - }, - "parameters": [ - { - "name": "request", - "type": "google.ai.generativelanguage_v1beta2.types.EmbedTextRequest" - }, - { - "name": "model", - "type": "str" - }, - { - "name": "text", - "type": "str" - }, - { - "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "google.ai.generativelanguage_v1beta2.types.EmbedTextResponse", - "shortName": "embed_text" - }, - "description": "Sample for EmbedText", - "file": "generativelanguage_v1beta2_generated_text_service_embed_text_sync.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "generativelanguage_v1beta2_generated_TextService_EmbedText_sync", - "segments": [ - { - "end": 52, - "start": 27, - "type": "FULL" - }, - { - "end": 52, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 46, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 49, - "start": 47, - "type": "REQUEST_EXECUTION" - }, - { - "end": 53, - "start": 50, - "type": "RESPONSE_HANDLING" - } - ], - "title": "generativelanguage_v1beta2_generated_text_service_embed_text_sync.py" - }, - { - "canonical": true, - "clientMethod": { - "async": true, - "client": { - "fullName": "google.ai.generativelanguage_v1beta2.TextServiceAsyncClient", - "shortName": "TextServiceAsyncClient" - }, - "fullName": "google.ai.generativelanguage_v1beta2.TextServiceAsyncClient.generate_text", - "method": { - "fullName": "google.ai.generativelanguage.v1beta2.TextService.GenerateText", - "service": { - "fullName": "google.ai.generativelanguage.v1beta2.TextService", - "shortName": "TextService" - }, - "shortName": "GenerateText" - }, - "parameters": [ - { - "name": "request", - "type": "google.ai.generativelanguage_v1beta2.types.GenerateTextRequest" - }, - { - "name": "model", - "type": "str" - }, - { - "name": "prompt", - "type": "google.ai.generativelanguage_v1beta2.types.TextPrompt" - }, - { - "name": "temperature", - "type": "float" - }, - { - "name": "candidate_count", - "type": "int" - }, - { - "name": "max_output_tokens", - "type": "int" - }, - { - "name": "top_p", - "type": "float" - }, - { - "name": "top_k", - "type": "int" - }, - { 
- "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "google.ai.generativelanguage_v1beta2.types.GenerateTextResponse", - "shortName": "generate_text" - }, - "description": "Sample for GenerateText", - "file": "generativelanguage_v1beta2_generated_text_service_generate_text_async.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "generativelanguage_v1beta2_generated_TextService_GenerateText_async", - "segments": [ - { - "end": 55, - "start": 27, - "type": "FULL" - }, - { - "end": 55, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 49, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 52, - "start": 50, - "type": "REQUEST_EXECUTION" - }, - { - "end": 56, - "start": 53, - "type": "RESPONSE_HANDLING" - } - ], - "title": "generativelanguage_v1beta2_generated_text_service_generate_text_async.py" - }, - { - "canonical": true, - "clientMethod": { - "client": { - "fullName": "google.ai.generativelanguage_v1beta2.TextServiceClient", - "shortName": "TextServiceClient" - }, - "fullName": "google.ai.generativelanguage_v1beta2.TextServiceClient.generate_text", - "method": { - "fullName": "google.ai.generativelanguage.v1beta2.TextService.GenerateText", - "service": { - "fullName": "google.ai.generativelanguage.v1beta2.TextService", - "shortName": "TextService" - }, - "shortName": "GenerateText" - }, - "parameters": [ - { - "name": "request", - "type": "google.ai.generativelanguage_v1beta2.types.GenerateTextRequest" - }, - { - "name": "model", - "type": "str" - }, - { - "name": "prompt", - "type": "google.ai.generativelanguage_v1beta2.types.TextPrompt" - }, - { - "name": "temperature", - "type": "float" - }, - { - "name": "candidate_count", - "type": "int" - }, - { - "name": "max_output_tokens", - "type": "int" - }, - { - "name": "top_p", - "type": "float" - }, - { - "name": "top_k", - "type": "int" - }, - { - "name": "retry", - "type": "google.api_core.retry.Retry" - }, - { - "name": "timeout", - "type": "float" - }, - { - "name": "metadata", - "type": "Sequence[Tuple[str, str]" - } - ], - "resultType": "google.ai.generativelanguage_v1beta2.types.GenerateTextResponse", - "shortName": "generate_text" - }, - "description": "Sample for GenerateText", - "file": "generativelanguage_v1beta2_generated_text_service_generate_text_sync.py", - "language": "PYTHON", - "origin": "API_DEFINITION", - "regionTag": "generativelanguage_v1beta2_generated_TextService_GenerateText_sync", - "segments": [ - { - "end": 55, - "start": 27, - "type": "FULL" - }, - { - "end": 55, - "start": 27, - "type": "SHORT" - }, - { - "end": 40, - "start": 38, - "type": "CLIENT_INITIALIZATION" - }, - { - "end": 49, - "start": 41, - "type": "REQUEST_INITIALIZATION" - }, - { - "end": 52, - "start": 50, - "type": "REQUEST_EXECUTION" - }, - { - "end": 56, - "start": 53, - "type": "RESPONSE_HANDLING" - } - ], - "title": "generativelanguage_v1beta2_generated_text_service_generate_text_sync.py" - } - ] -} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/scripts/fixup_generativelanguage_v1beta2_keywords.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/scripts/fixup_generativelanguage_v1beta2_keywords.py deleted file mode 100644 index 0c638051d5bf..000000000000 --- 
a/owl-bot-staging/google-ai-generativelanguage/v1beta2/scripts/fixup_generativelanguage_v1beta2_keywords.py +++ /dev/null @@ -1,181 +0,0 @@ -#! /usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import argparse -import os -import libcst as cst -import pathlib -import sys -from typing import (Any, Callable, Dict, List, Sequence, Tuple) - - -def partition( - predicate: Callable[[Any], bool], - iterator: Sequence[Any] -) -> Tuple[List[Any], List[Any]]: - """A stable, out-of-place partition.""" - results = ([], []) - - for i in iterator: - results[int(predicate(i))].append(i) - - # Returns trueList, falseList - return results[1], results[0] - - -class generativelanguageCallTransformer(cst.CSTTransformer): - CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') - METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { - 'count_message_tokens': ('model', 'prompt', ), - 'embed_text': ('model', 'text', ), - 'generate_message': ('model', 'prompt', 'temperature', 'candidate_count', 'top_p', 'top_k', ), - 'generate_text': ('model', 'prompt', 'temperature', 'candidate_count', 'max_output_tokens', 'top_p', 'top_k', 'safety_settings', 'stop_sequences', ), - 'get_model': ('name', ), - 'list_models': ('page_size', 'page_token', ), - } - - def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: - try: - key = original.func.attr.value - kword_params = self.METHOD_TO_PARAMS[key] - except (AttributeError, KeyError): - # Either not a method from the API or too convoluted to be sure. - return updated - - # If the existing code is valid, keyword args come after positional args. - # Therefore, all positional args must map to the first parameters. - args, kwargs = partition(lambda a: not bool(a.keyword), updated.args) - if any(k.keyword.value == "request" for k in kwargs): - # We've already fixed this file, don't fix it again. - return updated - - kwargs, ctrl_kwargs = partition( - lambda a: a.keyword.value not in self.CTRL_PARAMS, - kwargs - ) - - args, ctrl_args = args[:len(kword_params)], args[len(kword_params):] - ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl)) - for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS)) - - request_arg = cst.Arg( - value=cst.Dict([ - cst.DictElement( - cst.SimpleString("'{}'".format(name)), -cst.Element(value=arg.value) - ) - # Note: the args + kwargs looks silly, but keep in mind that - # the control parameters had to be stripped out, and that - # those could have been passed positionally or by keyword. - for name, arg in zip(kword_params, args + kwargs)]), - keyword=cst.Name("request") - ) - - return updated.with_changes( - args=[request_arg] + ctrl_kwargs - ) - - -def fix_files( - in_dir: pathlib.Path, - out_dir: pathlib.Path, - *, - transformer=generativelanguageCallTransformer(), -): - """Duplicate the input dir to the output dir, fixing file method calls. 
- - Preconditions: - * in_dir is a real directory - * out_dir is a real, empty directory - """ - pyfile_gen = ( - pathlib.Path(os.path.join(root, f)) - for root, _, files in os.walk(in_dir) - for f in files if os.path.splitext(f)[1] == ".py" - ) - - for fpath in pyfile_gen: - with open(fpath, 'r') as f: - src = f.read() - - # Parse the code and insert method call fixes. - tree = cst.parse_module(src) - updated = tree.visit(transformer) - - # Create the path and directory structure for the new file. - updated_path = out_dir.joinpath(fpath.relative_to(in_dir)) - updated_path.parent.mkdir(parents=True, exist_ok=True) - - # Generate the updated source file at the corresponding path. - with open(updated_path, 'w') as f: - f.write(updated.code) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser( - description="""Fix up source that uses the generativelanguage client library. - -The existing sources are NOT overwritten but are copied to output_dir with changes made. - -Note: This tool operates at a best-effort level at converting positional - parameters in client method calls to keyword based parameters. - Cases where it WILL FAIL include - A) * or ** expansion in a method call. - B) Calls via function or method alias (includes free function calls) - C) Indirect or dispatched calls (e.g. the method is looked up dynamically) - - These all constitute false negatives. The tool will also detect false - positives when an API method shares a name with another method. -""") - parser.add_argument( - '-d', - '--input-directory', - required=True, - dest='input_dir', - help='the input directory to walk for python files to fix up', - ) - parser.add_argument( - '-o', - '--output-directory', - required=True, - dest='output_dir', - help='the directory to output files fixed via un-flattening', - ) - args = parser.parse_args() - input_dir = pathlib.Path(args.input_dir) - output_dir = pathlib.Path(args.output_dir) - if not input_dir.is_dir(): - print( - f"input directory '{input_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if not output_dir.is_dir(): - print( - f"output directory '{output_dir}' does not exist or is not a directory", - file=sys.stderr, - ) - sys.exit(-1) - - if os.listdir(output_dir): - print( - f"output directory '{output_dir}' is not empty", - file=sys.stderr, - ) - sys.exit(-1) - - fix_files(input_dir, output_dir) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/setup.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/setup.py deleted file mode 100644 index 0e0b1e55d45f..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/setup.py +++ /dev/null @@ -1,90 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
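# The fixup script above exists because newer GAPIC clients take a single
# `request` object instead of flattened positional arguments; the transformer
# folds the old call shape into the new one while keeping the control
# parameters (`retry`, `timeout`, `metadata`) at the top level. A hedged
# sketch of the same folding on plain data; `fold_into_request` and the
# model name are illustrative, not part of the script:
def fold_into_request(method_params, ctrl_params, args, kwargs):
    # Positional args map onto the leading declared params, exactly as
    # leave_Call zips `kword_params` with `args + kwargs`.
    request = dict(zip(method_params, args))
    request.update({k: v for k, v in kwargs.items() if k not in ctrl_params})
    ctrl = {k: v for k, v in kwargs.items() if k in ctrl_params}
    return {"request": request, **ctrl}

# client.generate_text('models/sample', temperature=0.7, retry=r) becomes
# client.generate_text(request={...}, retry=r):
assert fold_into_request(
    ("model", "prompt", "temperature"), ("retry", "timeout", "metadata"),
    ("models/sample",), {"temperature": 0.7, "retry": "r"},
) == {"request": {"model": "models/sample", "temperature": 0.7}, "retry": "r"}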
-# -import io -import os - -import setuptools # type: ignore - -package_root = os.path.abspath(os.path.dirname(__file__)) - -name = 'google-ai-generativelanguage' - - -description = "Google Ai Generativelanguage API client library" - -version = {} -with open(os.path.join(package_root, 'google/ai/generativelanguage/gapic_version.py')) as fp: - exec(fp.read(), version) -version = version["__version__"] - -if version[0] == "0": - release_status = "Development Status :: 4 - Beta" -else: - release_status = "Development Status :: 5 - Production/Stable" - -dependencies = [ - "google-api-core[grpc] >= 1.34.0, <3.0.0dev,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*,!=2.10.*", - "proto-plus >= 1.22.0, <2.0.0dev", - "proto-plus >= 1.22.2, <2.0.0dev; python_version>='3.11'", - "protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", -] -url = "https://github.com/googleapis/python-ai-generativelanguage" - -package_root = os.path.abspath(os.path.dirname(__file__)) - -readme_filename = os.path.join(package_root, "README.rst") -with io.open(readme_filename, encoding="utf-8") as readme_file: - readme = readme_file.read() - -packages = [ - package - for package in setuptools.PEP420PackageFinder.find() - if package.startswith("google") -] - -namespaces = ["google", "google.ai"] - -setuptools.setup( - name=name, - version=version, - description=description, - long_description=readme, - author="Google LLC", - author_email="googleapis-packages@google.com", - license="Apache 2.0", - url=url, - classifiers=[ - release_status, - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Operating System :: OS Independent", - "Topic :: Internet", - ], - platforms="Posix; MacOS X; Windows", - packages=packages, - python_requires=">=3.7", - namespace_packages=namespaces, - install_requires=dependencies, - include_package_data=True, - zip_safe=False, -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.10.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.10.txt deleted file mode 100644 index ed7f9aed2559..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.10.txt +++ /dev/null @@ -1,6 +0,0 @@ -# -*- coding: utf-8 -*- -# This constraints file is required for unit tests. -# List all library dependencies and extras in this file. -google-api-core -proto-plus -protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.11.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.11.txt deleted file mode 100644 index ed7f9aed2559..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.11.txt +++ /dev/null @@ -1,6 +0,0 @@ -# -*- coding: utf-8 -*- -# This constraints file is required for unit tests. -# List all library dependencies and extras in this file. 
-google-api-core -proto-plus -protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.12.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.12.txt deleted file mode 100644 index ed7f9aed2559..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.12.txt +++ /dev/null @@ -1,6 +0,0 @@ -# -*- coding: utf-8 -*- -# This constraints file is required for unit tests. -# List all library dependencies and extras in this file. -google-api-core -proto-plus -protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.7.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.7.txt deleted file mode 100644 index 6c44adfea7ee..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.7.txt +++ /dev/null @@ -1,9 +0,0 @@ -# This constraints file is used to check that lower bounds -# are correct in setup.py -# List all library dependencies and extras in this file. -# Pin the version to the lower bound. -# e.g., if setup.py has "google-cloud-foo >= 1.14.0, < 2.0.0dev", -# Then this file should have google-cloud-foo==1.14.0 -google-api-core==1.34.0 -proto-plus==1.22.0 -protobuf==3.19.5 diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.8.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.8.txt deleted file mode 100644 index ed7f9aed2559..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.8.txt +++ /dev/null @@ -1,6 +0,0 @@ -# -*- coding: utf-8 -*- -# This constraints file is required for unit tests. -# List all library dependencies and extras in this file. -google-api-core -proto-plus -protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.9.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.9.txt deleted file mode 100644 index ed7f9aed2559..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/testing/constraints-3.9.txt +++ /dev/null @@ -1,6 +0,0 @@ -# -*- coding: utf-8 -*- -# This constraints file is required for unit tests. -# List all library dependencies and extras in this file. -google-api-core -proto-plus -protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/__init__.py deleted file mode 100644 index 1b4db446eb8d..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ - -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
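# The constraints files above follow a deliberate convention, spelled out in
# constraints-3.7.txt: the file for the oldest supported Python pins every
# dependency to the exact lower bound declared in setup.py (e.g.
# "google-api-core[grpc] >= 1.34.0" becomes "google-api-core==1.34.0"), so a
# test run such as `pip install -e . -c testing/constraints-3.7.txt` proves
# the declared floors actually work, while the newer-Python files leave
# versions unpinned. A small self-contained check of that pairing;
# `lower_bound_pin` is a hypothetical helper, not repo tooling:
import re

def lower_bound_pin(requirement: str) -> str:
    # 'name[extra] >= X, < Y' from setup.py -> the 'name==X' pin expected
    # in a lower-bound constraints file.
    name = re.split(r"[\s\[>=<,;!]", requirement, maxsplit=1)[0]
    floor = re.search(r">=\s*([\w.]+)", requirement).group(1)
    return f"{name}=={floor}"

assert lower_bound_pin("proto-plus >= 1.22.0, <2.0.0dev") == "proto-plus==1.22.0"
assert lower_bound_pin("google-api-core[grpc] >= 1.34.0, <3.0.0dev") == "google-api-core==1.34.0"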
-# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/__init__.py deleted file mode 100644 index 1b4db446eb8d..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ - -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/__init__.py deleted file mode 100644 index 1b4db446eb8d..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ - -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/__init__.py deleted file mode 100644 index 1b4db446eb8d..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ - -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_discuss_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_discuss_service.py deleted file mode 100644 index fa35eaf42fd5..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_discuss_service.py +++ /dev/null @@ -1,2205 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import os -# try/except added for compatibility with python < 3.8 -try: - from unittest import mock - from unittest.mock import AsyncMock # pragma: NO COVER -except ImportError: # pragma: NO COVER - import mock - -import grpc -from grpc.experimental import aio -from collections.abc import Iterable -from google.protobuf import json_format -import json -import math -import pytest -from proto.marshal.rules.dates import DurationRule, TimestampRule -from proto.marshal.rules import wrappers -from requests import Response -from requests import Request, PreparedRequest -from requests.sessions import Session -from google.protobuf import json_format - -from google.ai.generativelanguage_v1beta2.services.discuss_service import DiscussServiceAsyncClient -from google.ai.generativelanguage_v1beta2.services.discuss_service import DiscussServiceClient -from google.ai.generativelanguage_v1beta2.services.discuss_service import transports -from google.ai.generativelanguage_v1beta2.types import citation -from google.ai.generativelanguage_v1beta2.types import discuss_service -from google.ai.generativelanguage_v1beta2.types import safety -from google.api_core import client_options -from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import path_template -from google.auth import credentials as ga_credentials -from google.auth.exceptions import MutualTLSChannelError -from google.oauth2 import service_account -import google.auth - - -def client_cert_source_callback(): - return b"cert bytes", b"key bytes" - - -# If default endpoint is localhost, then default mtls endpoint will be the same. -# This method modifies the default endpoint so the client can produce a different -# mtls endpoint for endpoint testing purposes. 
-def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT - - -def test__get_default_mtls_endpoint(): - api_endpoint = "example.googleapis.com" - api_mtls_endpoint = "example.mtls.googleapis.com" - sandbox_endpoint = "example.sandbox.googleapis.com" - sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" - non_googleapi = "api.example.com" - - assert DiscussServiceClient._get_default_mtls_endpoint(None) is None - assert DiscussServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert DiscussServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert DiscussServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert DiscussServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert DiscussServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - - -@pytest.mark.parametrize("client_class,transport_name", [ - (DiscussServiceClient, "grpc"), - (DiscussServiceAsyncClient, "grpc_asyncio"), - (DiscussServiceClient, "rest"), -]) -def test_discuss_service_client_from_service_account_info(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: - factory.return_value = creds - info = {"valid": True} - client = client_class.from_service_account_info(info, transport=transport_name) - assert client.transport._credentials == creds - assert isinstance(client, client_class) - - assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' - ) - - -@pytest.mark.parametrize("transport_class,transport_name", [ - (transports.DiscussServiceGrpcTransport, "grpc"), - (transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (transports.DiscussServiceRestTransport, "rest"), -]) -def test_discuss_service_client_service_account_always_use_jwt(transport_class, transport_name): - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: - creds = service_account.Credentials(None, None, None) - transport = transport_class(credentials=creds, always_use_jwt_access=True) - use_jwt.assert_called_once_with(True) - - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: - creds = service_account.Credentials(None, None, None) - transport = transport_class(credentials=creds, always_use_jwt_access=False) - use_jwt.assert_not_called() - - -@pytest.mark.parametrize("client_class,transport_name", [ - (DiscussServiceClient, "grpc"), - (DiscussServiceAsyncClient, "grpc_asyncio"), - (DiscussServiceClient, "rest"), -]) -def test_discuss_service_client_from_service_account_file(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: - factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name) - assert client.transport._credentials == creds - assert isinstance(client, client_class) - - client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name) - assert client.transport._credentials == creds - assert isinstance(client, client_class) - - assert 
client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' - ) - - -def test_discuss_service_client_get_transport_class(): - transport = DiscussServiceClient.get_transport_class() - available_transports = [ - transports.DiscussServiceGrpcTransport, - transports.DiscussServiceRestTransport, - ] - assert transport in available_transports - - transport = DiscussServiceClient.get_transport_class("grpc") - assert transport == transports.DiscussServiceGrpcTransport - - -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc"), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest"), -]) -@mock.patch.object(DiscussServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceClient)) -@mock.patch.object(DiscussServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceAsyncClient)) -def test_discuss_service_client_client_options(client_class, transport_class, transport_name): - # Check that if channel is provided we won't create a new one. - with mock.patch.object(DiscussServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=ga_credentials.AnonymousCredentials() - ) - client = client_class(transport=transport) - gtc.assert_not_called() - - # Check that if channel is provided via str we will create a new one. - with mock.patch.object(DiscussServiceClient, 'get_transport_class') as gtc: - client = client_class(transport=transport_name) - gtc.assert_called() - - # Check the case api_endpoint is provided. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(transport=transport_name, client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "always". 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has - # unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class(transport=transport_name) - - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): - with pytest.raises(ValueError): - client = client_class(transport=transport_name) - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id="octopus", - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - # Check the case api_endpoint is provided - options = client_options.ClientOptions(api_audience="https://language.googleapis.com") - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience="https://language.googleapis.com" - ) - -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", "true"), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", "false"), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), - (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", "true"), - (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", "false"), -]) -@mock.patch.object(DiscussServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceClient)) -@mock.patch.object(DiscussServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceAsyncClient)) -@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_discuss_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): - # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default - # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. 
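    # A condensed restatement of that autoswitch rule (with
    # GOOGLE_API_USE_MTLS_ENDPOINT left at "auto", as patched above);
    # `pick_endpoint` is an illustrative helper, not the client's code:
    def pick_endpoint(use_client_cert, have_cert, default, mtls_default):
        # The mTLS endpoint wins only when certs are both requested and available.
        return mtls_default if use_client_cert == "true" and have_cert else default

    assert pick_endpoint("true", True, "d", "d.mtls") == "d.mtls"
    assert pick_endpoint("false", True, "d", "d.mtls") == "d"
    assert pick_endpoint("true", False, "d", "d.mtls") == "d"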
- - # Check the case client_cert_source is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - - if use_client_cert_env == "false": - expected_client_cert_source = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_client_cert_source = client_cert_source_callback - expected_host = client.DEFAULT_MTLS_ENDPOINT - - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - client_cert_source_for_mtls=expected_client_cert_source, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case ADC client cert is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): - if use_client_cert_env == "false": - expected_host = client.DEFAULT_ENDPOINT - expected_client_cert_source = None - else: - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_client_cert_source = client_cert_source_callback - - patched.return_value = None - client = client_class(transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - client_cert_source_for_mtls=expected_client_cert_source, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): - patched.return_value = None - client = client_class(transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - -@pytest.mark.parametrize("client_class", [ - DiscussServiceClient, DiscussServiceAsyncClient -]) -@mock.patch.object(DiscussServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceClient)) -@mock.patch.object(DiscussServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceAsyncClient)) -def test_discuss_service_client_get_mtls_endpoint_and_cert_source(client_class): - mock_client_cert_source = mock.Mock() - - # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) - assert api_endpoint == mock_api_endpoint - assert cert_source == mock_client_cert_source - - # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): - mock_client_cert_source = mock.Mock() - mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) - assert api_endpoint == mock_api_endpoint - assert cert_source is None - - # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() - assert api_endpoint == client_class.DEFAULT_ENDPOINT - assert cert_source is None - - # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() - assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT - assert cert_source is None - - # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() - assert api_endpoint == client_class.DEFAULT_ENDPOINT - assert cert_source is None - - # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() - assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT - assert cert_source == mock_client_cert_source - - -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc"), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest"), -]) -def test_discuss_service_client_client_options_scopes(client_class, transport_class, transport_name): - # Check the case scopes are provided. 
- options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=["1", "2"], - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", grpc_helpers), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), - (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", None), -]) -def test_discuss_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): - # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - -def test_discuss_service_client_client_options_from_dict(): - with mock.patch('google.ai.generativelanguage_v1beta2.services.discuss_service.transports.DiscussServiceGrpcTransport.__init__') as grpc_transport: - grpc_transport.return_value = None - client = DiscussServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} - ) - grpc_transport.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", grpc_helpers), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), -]) -def test_discuss_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): - # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # test that the credentials from file are saved and used as the credentials. 
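    # The channel options asserted below lift gRPC's message-size limits
    # (notably the default 4 MB receive cap): -1 means unlimited for both send
    # and receive. For illustration, a plain channel built with the same
    # options (hypothetical local target); creating a channel does not connect:
    _demo_channel = grpc.insecure_channel(
        "localhost:50051",
        options=[
            ("grpc.max_send_message_length", -1),
            ("grpc.max_receive_message_length", -1),
        ],
    )
    _demo_channel.close()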
- with mock.patch.object( - google.auth, "load_credentials_from_file", autospec=True - ) as load_creds, mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel" - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - file_creds = ga_credentials.AnonymousCredentials() - load_creds.return_value = (file_creds, None) - adc.return_value = (creds, None) - client = client_class(client_options=options, transport=transport_name) - create_channel.assert_called_with( - "generativelanguage.googleapis.com:443", - credentials=file_creds, - credentials_file=None, - quota_project_id=None, - default_scopes=( -), - scopes=None, - default_host="generativelanguage.googleapis.com", - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize("request_type", [ - discuss_service.GenerateMessageRequest, - dict, -]) -def test_generate_message(request_type, transport: str = 'grpc'): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = discuss_service.GenerateMessageResponse( - ) - response = client.generate_message(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == discuss_service.GenerateMessageRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, discuss_service.GenerateMessageResponse) - - -def test_generate_message_empty_call(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: - client.generate_message() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == discuss_service.GenerateMessageRequest() - -@pytest.mark.asyncio -async def test_generate_message_async(transport: str = 'grpc_asyncio', request_type=discuss_service.GenerateMessageRequest): - client = DiscussServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.GenerateMessageResponse( - )) - response = await client.generate_message(request) - - # Establish that the underlying gRPC stub method was called. 
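    # Aside: the response above is wrapped in grpc_helpers_async.FakeUnaryUnaryCall
    # because the async client awaits the stub's result, and a bare Mock is not
    # awaitable. A minimal awaitable stand-in with the same shape (illustrative
    # only, not the google-api-core class):
    class _FakeCall:
        def __init__(self, response):
            self._response = response

        def __await__(self):
            async def _resolve():
                return self._response
            return _resolve().__await__()

    assert await _FakeCall("ok") == "ok"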
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == discuss_service.GenerateMessageRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, discuss_service.GenerateMessageResponse) - - -@pytest.mark.asyncio -async def test_generate_message_async_from_dict(): - await test_generate_message_async(request_type=dict) - - -def test_generate_message_field_headers(): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = discuss_service.GenerateMessageRequest() - - request.model = 'model_value' - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: - call.return_value = discuss_service.GenerateMessageResponse() - client.generate_message(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] - - -@pytest.mark.asyncio -async def test_generate_message_field_headers_async(): - client = DiscussServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = discuss_service.GenerateMessageRequest() - - request.model = 'model_value' - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.GenerateMessageResponse()) - await client.generate_message(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] - - -def test_generate_message_flattened(): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = discuss_service.GenerateMessageResponse() - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.generate_message( - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), - temperature=0.1198, - candidate_count=1573, - top_p=0.546, - top_k=541, - ) - - # Establish that the underlying call was made with the expected - # request object values. 
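    # What "flattened" means here: callers may pass EITHER a full request
    # object OR the individual keyword fields used above, never both; mixing
    # them raises ValueError (exercised by the *_flattened_error tests below).
    # A hedged sketch of the folding the client performs; `build_request` is
    # illustrative, not the generated helper:
    def build_request(request=None, **flattened):
        if request is not None and any(v is not None for v in flattened.values()):
            raise ValueError("If the `request` argument is set, then none of "
                             "the individual field arguments should be set.")
        if request is None:
            request = discuss_service.GenerateMessageRequest(**flattened)
        return request

    with pytest.raises(ValueError):
        build_request(request=discuss_service.GenerateMessageRequest(),
                      model='model_value')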
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - arg = args[0].model - mock_val = 'model_value' - assert arg == mock_val - arg = args[0].prompt - mock_val = discuss_service.MessagePrompt(context='context_value') - assert arg == mock_val - assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) - arg = args[0].candidate_count - mock_val = 1573 - assert arg == mock_val - assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6) - arg = args[0].top_k - mock_val = 541 - assert arg == mock_val - - -def test_generate_message_flattened_error(): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.generate_message( - discuss_service.GenerateMessageRequest(), - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), - temperature=0.1198, - candidate_count=1573, - top_p=0.546, - top_k=541, - ) - -@pytest.mark.asyncio -async def test_generate_message_flattened_async(): - client = DiscussServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = discuss_service.GenerateMessageResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.GenerateMessageResponse()) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.generate_message( - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), - temperature=0.1198, - candidate_count=1573, - top_p=0.546, - top_k=541, - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - arg = args[0].model - mock_val = 'model_value' - assert arg == mock_val - arg = args[0].prompt - mock_val = discuss_service.MessagePrompt(context='context_value') - assert arg == mock_val - assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) - arg = args[0].candidate_count - mock_val = 1573 - assert arg == mock_val - assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6) - arg = args[0].top_k - mock_val = 541 - assert arg == mock_val - -@pytest.mark.asyncio -async def test_generate_message_flattened_error_async(): - client = DiscussServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.generate_message( - discuss_service.GenerateMessageRequest(), - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), - temperature=0.1198, - candidate_count=1573, - top_p=0.546, - top_k=541, - ) - - -@pytest.mark.parametrize("request_type", [ - discuss_service.CountMessageTokensRequest, - dict, -]) -def test_count_message_tokens(request_type, transport: str = 'grpc'): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. 
- request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = discuss_service.CountMessageTokensResponse( - token_count=1193, - ) - response = client.count_message_tokens(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == discuss_service.CountMessageTokensRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, discuss_service.CountMessageTokensResponse) - assert response.token_count == 1193 - - -def test_count_message_tokens_empty_call(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: - client.count_message_tokens() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == discuss_service.CountMessageTokensRequest() - -@pytest.mark.asyncio -async def test_count_message_tokens_async(transport: str = 'grpc_asyncio', request_type=discuss_service.CountMessageTokensRequest): - client = DiscussServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.CountMessageTokensResponse( - token_count=1193, - )) - response = await client.count_message_tokens(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == discuss_service.CountMessageTokensRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, discuss_service.CountMessageTokensResponse) - assert response.token_count == 1193 - - -@pytest.mark.asyncio -async def test_count_message_tokens_async_from_dict(): - await test_count_message_tokens_async(request_type=dict) - - -def test_count_message_tokens_field_headers(): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = discuss_service.CountMessageTokensRequest() - - request.model = 'model_value' - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: - call.return_value = discuss_service.CountMessageTokensResponse() - client.count_message_tokens(request) - - # Establish that the underlying gRPC stub method was called. 
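    # The metadata entry checked below implements implicit routing: URI path
    # fields from the request are URL-encoded into a single
    # 'x-goog-request-params' header so the backend can route the call. A
    # small illustrative encoding of the same pair (not the library's helper):
    from urllib.parse import urlencode
    assert ('x-goog-request-params', urlencode({'model': request.model})) == \
        ('x-goog-request-params', 'model=model_value')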
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] - - -@pytest.mark.asyncio -async def test_count_message_tokens_field_headers_async(): - client = DiscussServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = discuss_service.CountMessageTokensRequest() - - request.model = 'model_value' - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.CountMessageTokensResponse()) - await client.count_message_tokens(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] - - -def test_count_message_tokens_flattened(): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = discuss_service.CountMessageTokensResponse() - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.count_message_tokens( - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - arg = args[0].model - mock_val = 'model_value' - assert arg == mock_val - arg = args[0].prompt - mock_val = discuss_service.MessagePrompt(context='context_value') - assert arg == mock_val - - -def test_count_message_tokens_flattened_error(): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.count_message_tokens( - discuss_service.CountMessageTokensRequest(), - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), - ) - -@pytest.mark.asyncio -async def test_count_message_tokens_flattened_async(): - client = DiscussServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = discuss_service.CountMessageTokensResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.CountMessageTokensResponse()) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. 
- response = await client.count_message_tokens( - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - arg = args[0].model - mock_val = 'model_value' - assert arg == mock_val - arg = args[0].prompt - mock_val = discuss_service.MessagePrompt(context='context_value') - assert arg == mock_val - -@pytest.mark.asyncio -async def test_count_message_tokens_flattened_error_async(): - client = DiscussServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.count_message_tokens( - discuss_service.CountMessageTokensRequest(), - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), - ) - - -@pytest.mark.parametrize("request_type", [ - discuss_service.GenerateMessageRequest, - dict, -]) -def test_generate_message_rest(request_type): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - - # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} - request = request_type(**request_init) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: - # Designate an appropriate value for the returned response. - return_value = discuss_service.GenerateMessageResponse( - ) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - pb_return_value = discuss_service.GenerateMessageResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) - - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - response = client.generate_message(request) - - # Establish that the response is the type that we expect. - assert isinstance(response, discuss_service.GenerateMessageResponse) - - -def test_generate_message_rest_required_fields(request_type=discuss_service.GenerateMessageRequest): - transport_class = transports.DiscussServiceRestTransport - - request_init = {} - request_init["model"] = "" - request = request_type(**request_init) - pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) - - # verify fields with default values are dropped - - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_message._get_unset_required_fields(jsonified_request) - jsonified_request.update(unset_fields) - - # verify required fields with default values are now present - - jsonified_request["model"] = 'model_value' - - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_message._get_unset_required_fields(jsonified_request) - jsonified_request.update(unset_fields) - - # verify required fields with non-default values are left alone - assert "model" in jsonified_request - assert jsonified_request["model"] == 'model_value' - - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest', - ) - request = request_type(**request_init) - - # Designate an appropriate value for the returned response. 
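-    # The REST transport decodes the HTTP body, so the fake response below is
-    # serialized to proto JSON exactly as a real server reply would be.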
-    return_value = discuss_service.GenerateMessageResponse()
-    # Mock the http request call within the method and fake a response.
-    with mock.patch.object(Session, 'request') as req:
-        # We need to mock transcode() because providing default values
-        # for required fields will fail the real version if the http_options
-        # expect actual values for those fields.
-        with mock.patch.object(path_template, 'transcode') as transcode:
-            # A uri without fields and an empty body will force all the
-            # request fields to show up in the query_params.
-            pb_request = request_type.pb(request)
-            transcode_result = {
-                'uri': 'v1/sample_method',
-                'method': "post",
-                'query_params': pb_request,
-            }
-            transcode_result['body'] = pb_request
-            transcode.return_value = transcode_result
-
-            response_value = Response()
-            response_value.status_code = 200
-
-            pb_return_value = discuss_service.GenerateMessageResponse.pb(return_value)
-            json_return_value = json_format.MessageToJson(pb_return_value)
-
-            response_value._content = json_return_value.encode('UTF-8')
-            req.return_value = response_value
-
-            response = client.generate_message(request)
-
-            expected_params = [
-                ('$alt', 'json;enum-encoding=int')
-            ]
-            actual_params = req.call_args.kwargs['params']
-            assert expected_params == actual_params
-
-
-def test_generate_message_rest_unset_required_fields():
-    transport = transports.DiscussServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
-
-    unset_fields = transport.generate_message._get_unset_required_fields({})
-    assert set(unset_fields) == (set(()) & set(("model", "prompt", )))
-
-
-@pytest.mark.parametrize("null_interceptor", [True, False])
-def test_generate_message_rest_interceptors(null_interceptor):
-    transport = transports.DiscussServiceRestTransport(
-        credentials=ga_credentials.AnonymousCredentials(),
-        interceptor=None if null_interceptor else transports.DiscussServiceRestInterceptor(),
-        )
-    client = DiscussServiceClient(transport=transport)
-    with mock.patch.object(type(client.transport._session), "request") as req, \
-         mock.patch.object(path_template, "transcode") as transcode, \
-         mock.patch.object(transports.DiscussServiceRestInterceptor, "post_generate_message") as post, \
-         mock.patch.object(transports.DiscussServiceRestInterceptor, "pre_generate_message") as pre:
-        pre.assert_not_called()
-        post.assert_not_called()
-        pb_message = discuss_service.GenerateMessageRequest.pb(discuss_service.GenerateMessageRequest())
-        transcode.return_value = {
-            "method": "post",
-            "uri": "my_uri",
-            "body": pb_message,
-            "query_params": pb_message,
-        }
-
-        req.return_value = Response()
-        req.return_value.status_code = 200
-        req.return_value.request = PreparedRequest()
-        req.return_value._content = discuss_service.GenerateMessageResponse.to_json(discuss_service.GenerateMessageResponse())
-
-        request = discuss_service.GenerateMessageRequest()
-        metadata = [
-            ("key", "val"),
-            ("cephalopod", "squid"),
-        ]
-        pre.return_value = request, metadata
-        post.return_value = discuss_service.GenerateMessageResponse()
-
-        client.generate_message(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
-
-        pre.assert_called_once()
-        post.assert_called_once()
-
-
-def test_generate_message_rest_bad_request(transport: str = 'rest', request_type=discuss_service.GenerateMessageRequest):
-    client = DiscussServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-        transport=transport,
-    )
-
-    # send a request that will satisfy transcoding
-    request_init = {'model': 'models/sample1'}
-    request = 
request_type(**request_init) - - # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 400 - response_value.request = Request() - req.return_value = response_value - client.generate_message(request) - - -def test_generate_message_rest_flattened(): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: - # Designate an appropriate value for the returned response. - return_value = discuss_service.GenerateMessageResponse() - - # get arguments that satisfy an http rule for this method - sample_request = {'model': 'models/sample1'} - - # get truthy value for each flattened field - mock_args = dict( - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), - temperature=0.1198, - candidate_count=1573, - top_p=0.546, - top_k=541, - ) - mock_args.update(sample_request) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - pb_return_value = discuss_service.GenerateMessageResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - - client.generate_message(**mock_args) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(req.mock_calls) == 1 - _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta2/{model=models/*}:generateMessage" % client.transport._host, args[1]) - - -def test_generate_message_rest_flattened_error(transport: str = 'rest'): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.generate_message( - discuss_service.GenerateMessageRequest(), - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), - temperature=0.1198, - candidate_count=1573, - top_p=0.546, - top_k=541, - ) - - -def test_generate_message_rest_error(): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' - ) - - -@pytest.mark.parametrize("request_type", [ - discuss_service.CountMessageTokensRequest, - dict, -]) -def test_count_message_tokens_rest(request_type): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - - # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} - request = request_type(**request_init) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: - # Designate an appropriate value for the returned response. 
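-        # token_count is a sentinel value; the assertions below verify that
-        # it survives the JSON round trip through the mocked response.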
- return_value = discuss_service.CountMessageTokensResponse( - token_count=1193, - ) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - pb_return_value = discuss_service.CountMessageTokensResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) - - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - response = client.count_message_tokens(request) - - # Establish that the response is the type that we expect. - assert isinstance(response, discuss_service.CountMessageTokensResponse) - assert response.token_count == 1193 - - -def test_count_message_tokens_rest_required_fields(request_type=discuss_service.CountMessageTokensRequest): - transport_class = transports.DiscussServiceRestTransport - - request_init = {} - request_init["model"] = "" - request = request_type(**request_init) - pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) - - # verify fields with default values are dropped - - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).count_message_tokens._get_unset_required_fields(jsonified_request) - jsonified_request.update(unset_fields) - - # verify required fields with default values are now present - - jsonified_request["model"] = 'model_value' - - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).count_message_tokens._get_unset_required_fields(jsonified_request) - jsonified_request.update(unset_fields) - - # verify required fields with non-default values are left alone - assert "model" in jsonified_request - assert jsonified_request["model"] == 'model_value' - - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest', - ) - request = request_type(**request_init) - - # Designate an appropriate value for the returned response. - return_value = discuss_service.CountMessageTokensResponse() - # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: - # We need to mock transcode() because providing default values - # for required fields will fail the real version if the http_options - # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: - # A uri without fields and an empty body will force all the - # request fields to show up in the query_params. 
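-            # The transcode_result keys mirror what path_template.transcode
-            # returns for a real http rule: uri, method, query_params, body.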
-            pb_request = request_type.pb(request)
-            transcode_result = {
-                'uri': 'v1/sample_method',
-                'method': "post",
-                'query_params': pb_request,
-            }
-            transcode_result['body'] = pb_request
-            transcode.return_value = transcode_result
-
-            response_value = Response()
-            response_value.status_code = 200
-
-            pb_return_value = discuss_service.CountMessageTokensResponse.pb(return_value)
-            json_return_value = json_format.MessageToJson(pb_return_value)
-
-            response_value._content = json_return_value.encode('UTF-8')
-            req.return_value = response_value
-
-            response = client.count_message_tokens(request)
-
-            expected_params = [
-                ('$alt', 'json;enum-encoding=int')
-            ]
-            actual_params = req.call_args.kwargs['params']
-            assert expected_params == actual_params
-
-
-def test_count_message_tokens_rest_unset_required_fields():
-    transport = transports.DiscussServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
-
-    unset_fields = transport.count_message_tokens._get_unset_required_fields({})
-    assert set(unset_fields) == (set(()) & set(("model", "prompt", )))
-
-
-@pytest.mark.parametrize("null_interceptor", [True, False])
-def test_count_message_tokens_rest_interceptors(null_interceptor):
-    transport = transports.DiscussServiceRestTransport(
-        credentials=ga_credentials.AnonymousCredentials(),
-        interceptor=None if null_interceptor else transports.DiscussServiceRestInterceptor(),
-        )
-    client = DiscussServiceClient(transport=transport)
-    with mock.patch.object(type(client.transport._session), "request") as req, \
-         mock.patch.object(path_template, "transcode") as transcode, \
-         mock.patch.object(transports.DiscussServiceRestInterceptor, "post_count_message_tokens") as post, \
-         mock.patch.object(transports.DiscussServiceRestInterceptor, "pre_count_message_tokens") as pre:
-        pre.assert_not_called()
-        post.assert_not_called()
-        pb_message = discuss_service.CountMessageTokensRequest.pb(discuss_service.CountMessageTokensRequest())
-        transcode.return_value = {
-            "method": "post",
-            "uri": "my_uri",
-            "body": pb_message,
-            "query_params": pb_message,
-        }
-
-        req.return_value = Response()
-        req.return_value.status_code = 200
-        req.return_value.request = PreparedRequest()
-        req.return_value._content = discuss_service.CountMessageTokensResponse.to_json(discuss_service.CountMessageTokensResponse())
-
-        request = discuss_service.CountMessageTokensRequest()
-        metadata = [
-            ("key", "val"),
-            ("cephalopod", "squid"),
-        ]
-        pre.return_value = request, metadata
-        post.return_value = discuss_service.CountMessageTokensResponse()
-
-        client.count_message_tokens(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
-
-        pre.assert_called_once()
-        post.assert_called_once()
-
-
-def test_count_message_tokens_rest_bad_request(transport: str = 'rest', request_type=discuss_service.CountMessageTokensRequest):
-    client = DiscussServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-        transport=transport,
-    )
-
-    # send a request that will satisfy transcoding
-    request_init = {'model': 'models/sample1'}
-    request = request_type(**request_init)
-
-    # Mock the http request call within the method and fake a BadRequest error.
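-    # A 400 status on the mocked session is expected to surface as
-    # core_exceptions.BadRequest through the client's error mapping.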
- with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 400 - response_value.request = Request() - req.return_value = response_value - client.count_message_tokens(request) - - -def test_count_message_tokens_rest_flattened(): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: - # Designate an appropriate value for the returned response. - return_value = discuss_service.CountMessageTokensResponse() - - # get arguments that satisfy an http rule for this method - sample_request = {'model': 'models/sample1'} - - # get truthy value for each flattened field - mock_args = dict( - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), - ) - mock_args.update(sample_request) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - pb_return_value = discuss_service.CountMessageTokensResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - - client.count_message_tokens(**mock_args) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(req.mock_calls) == 1 - _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta2/{model=models/*}:countMessageTokens" % client.transport._host, args[1]) - - -def test_count_message_tokens_rest_flattened_error(transport: str = 'rest'): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.count_message_tokens( - discuss_service.CountMessageTokensRequest(), - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), - ) - - -def test_count_message_tokens_rest_error(): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' - ) - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.DiscussServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # It is an error to provide a credentials file and a transport instance. - transport = transports.DiscussServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = DiscussServiceClient( - client_options={"credentials_file": "credentials.json"}, - transport=transport, - ) - - # It is an error to provide an api_key and a transport instance. - transport = transports.DiscussServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - options = client_options.ClientOptions() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = DiscussServiceClient( - client_options=options, - transport=transport, - ) - - # It is an error to provide an api_key and a credential. 
- options = mock.Mock() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = DiscussServiceClient( - client_options=options, - credentials=ga_credentials.AnonymousCredentials() - ) - - # It is an error to provide scopes and a transport instance. - transport = transports.DiscussServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = DiscussServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. - transport = transports.DiscussServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - client = DiscussServiceClient(transport=transport) - assert client.transport is transport - -def test_transport_get_channel(): - # A client may be instantiated with a custom transport instance. - transport = transports.DiscussServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - - transport = transports.DiscussServiceGrpcAsyncIOTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - -@pytest.mark.parametrize("transport_class", [ - transports.DiscussServiceGrpcTransport, - transports.DiscussServiceGrpcAsyncIOTransport, - transports.DiscussServiceRestTransport, -]) -def test_transport_adc(transport_class): - # Test default credentials are used if not provided. - with mock.patch.object(google.auth, 'default') as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - transport_class() - adc.assert_called_once() - -@pytest.mark.parametrize("transport_name", [ - "grpc", - "rest", -]) -def test_transport_kind(transport_name): - transport = DiscussServiceClient.get_transport_class(transport_name)( - credentials=ga_credentials.AnonymousCredentials(), - ) - assert transport.kind == transport_name - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.DiscussServiceGrpcTransport, - ) - -def test_discuss_service_base_transport_error(): - # Passing both a credentials object and credentials_file should raise an error - with pytest.raises(core_exceptions.DuplicateCredentialArgs): - transport = transports.DiscussServiceTransport( - credentials=ga_credentials.AnonymousCredentials(), - credentials_file="credentials.json" - ) - - -def test_discuss_service_base_transport(): - # Instantiate the base transport. - with mock.patch('google.ai.generativelanguage_v1beta2.services.discuss_service.transports.DiscussServiceTransport.__init__') as Transport: - Transport.return_value = None - transport = transports.DiscussServiceTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. 
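-    # The base transport only defines the interface; the concrete gRPC and
-    # REST transports are expected to override each of these stubs.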
- methods = ( - 'generate_message', - 'count_message_tokens', - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - with pytest.raises(NotImplementedError): - transport.close() - - # Catch all for all remaining methods and properties - remainder = [ - 'kind', - ] - for r in remainder: - with pytest.raises(NotImplementedError): - getattr(transport, r)() - - -def test_discuss_service_base_transport_with_credentials_file(): - # Instantiate the base transport with a credentials file - with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta2.services.discuss_service.transports.DiscussServiceTransport._prep_wrapped_messages') as Transport: - Transport.return_value = None - load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) - transport = transports.DiscussServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", - ) - load_creds.assert_called_once_with("credentials.json", - scopes=None, - default_scopes=( -), - quota_project_id="octopus", - ) - - -def test_discuss_service_base_transport_with_adc(): - # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta2.services.discuss_service.transports.DiscussServiceTransport._prep_wrapped_messages') as Transport: - Transport.return_value = None - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - transport = transports.DiscussServiceTransport() - adc.assert_called_once() - - -def test_discuss_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(google.auth, 'default', autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - DiscussServiceClient() - adc.assert_called_once_with( - scopes=None, - default_scopes=( -), - quota_project_id=None, - ) - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.DiscussServiceGrpcTransport, - transports.DiscussServiceGrpcAsyncIOTransport, - ], -) -def test_discuss_service_transport_auth_adc(transport_class): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
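-    # google.auth.default is patched so no real credential lookup happens
-    # in the test environment.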
- with mock.patch.object(google.auth, 'default', autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - adc.assert_called_once_with( - scopes=["1", "2"], - default_scopes=(), - quota_project_id="octopus", - ) - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.DiscussServiceGrpcTransport, - transports.DiscussServiceGrpcAsyncIOTransport, - transports.DiscussServiceRestTransport, - ], -) -def test_discuss_service_transport_auth_gdch_credentials(transport_class): - host = 'https://language.com' - api_audience_tests = [None, 'https://language2.com'] - api_audience_expect = [host, 'https://language2.com'] - for t, e in zip(api_audience_tests, api_audience_expect): - with mock.patch.object(google.auth, 'default', autospec=True) as adc: - gdch_mock = mock.MagicMock() - type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock) - adc.return_value = (gdch_mock, None) - transport_class(host=host, api_audience=t) - gdch_mock.with_gdch_audience.assert_called_once_with( - e - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.DiscussServiceGrpcTransport, grpc_helpers), - (transports.DiscussServiceGrpcAsyncIOTransport, grpc_helpers_async) - ], -) -def test_discuss_service_transport_create_channel(transport_class, grpc_helpers): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class( - quota_project_id="octopus", - scopes=["1", "2"] - ) - - create_channel.assert_called_with( - "generativelanguage.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - default_scopes=( -), - scopes=["1", "2"], - default_host="generativelanguage.googleapis.com", - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize("transport_class", [transports.DiscussServiceGrpcTransport, transports.DiscussServiceGrpcAsyncIOTransport]) -def test_discuss_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): - cred = ga_credentials.AnonymousCredentials() - - # Check ssl_channel_credentials is used if provided. - with mock.patch.object(transport_class, "create_channel") as mock_create_channel: - mock_ssl_channel_creds = mock.Mock() - transport_class( - host="squid.clam.whelk", - credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds - ) - mock_create_channel.assert_called_once_with( - "squid.clam.whelk:443", - credentials=cred, - credentials_file=None, - scopes=None, - ssl_credentials=mock_ssl_channel_creds, - quota_project_id=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls - # is used. 
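-    # client_cert_source_callback returns (cert bytes, key bytes), which the
-    # transport should hand to grpc.ssl_channel_credentials.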
- with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): - with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: - transport_class( - credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback - ) - expected_cert, expected_key = client_cert_source_callback() - mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key - ) - -def test_discuss_service_http_transport_client_cert_source_for_mtls(): - cred = ga_credentials.AnonymousCredentials() - with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel: - transports.DiscussServiceRestTransport ( - credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback - ) - mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) - - -@pytest.mark.parametrize("transport_name", [ - "grpc", - "grpc_asyncio", - "rest", -]) -def test_discuss_service_host_no_port(transport_name): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'), - transport=transport_name, - ) - assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else 'https://generativelanguage.googleapis.com' - ) - -@pytest.mark.parametrize("transport_name", [ - "grpc", - "grpc_asyncio", - "rest", -]) -def test_discuss_service_host_with_port(transport_name): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'), - transport=transport_name, - ) - assert client.transport._host == ( - 'generativelanguage.googleapis.com:8000' - if transport_name in ['grpc', 'grpc_asyncio'] - else 'https://generativelanguage.googleapis.com:8000' - ) - -@pytest.mark.parametrize("transport_name", [ - "rest", -]) -def test_discuss_service_client_transport_session_collision(transport_name): - creds1 = ga_credentials.AnonymousCredentials() - creds2 = ga_credentials.AnonymousCredentials() - client1 = DiscussServiceClient( - credentials=creds1, - transport=transport_name, - ) - client2 = DiscussServiceClient( - credentials=creds2, - transport=transport_name, - ) - session1 = client1.transport.generate_message._session - session2 = client2.transport.generate_message._session - assert session1 != session2 - session1 = client1.transport.count_message_tokens._session - session2 = client2.transport.count_message_tokens._session - assert session1 != session2 -def test_discuss_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) - - # Check that channel is used if provided. - transport = transports.DiscussServiceGrpcTransport( - host="squid.clam.whelk", - channel=channel, - ) - assert transport.grpc_channel == channel - assert transport._host == "squid.clam.whelk:443" - assert transport._ssl_channel_credentials == None - - -def test_discuss_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) - - # Check that channel is used if provided. 
-    transport = transports.DiscussServiceGrpcAsyncIOTransport(
-        host="squid.clam.whelk",
-        channel=channel,
-    )
-    assert transport.grpc_channel == channel
-    assert transport._host == "squid.clam.whelk:443"
-    assert transport._ssl_channel_credentials is None
-
-
-# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
-# removed from grpc/grpc_asyncio transport constructor.
-@pytest.mark.parametrize("transport_class", [transports.DiscussServiceGrpcTransport, transports.DiscussServiceGrpcAsyncIOTransport])
-def test_discuss_service_transport_channel_mtls_with_client_cert_source(
-    transport_class
-):
-    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
-        with mock.patch.object(transport_class, "create_channel") as grpc_create_channel:
-            mock_ssl_cred = mock.Mock()
-            grpc_ssl_channel_cred.return_value = mock_ssl_cred
-
-            mock_grpc_channel = mock.Mock()
-            grpc_create_channel.return_value = mock_grpc_channel
-
-            cred = ga_credentials.AnonymousCredentials()
-            with pytest.warns(DeprecationWarning):
-                with mock.patch.object(google.auth, 'default') as adc:
-                    adc.return_value = (cred, None)
-                    transport = transport_class(
-                        host="squid.clam.whelk",
-                        api_mtls_endpoint="mtls.squid.clam.whelk",
-                        client_cert_source=client_cert_source_callback,
-                    )
-                    adc.assert_called_once()
-
-            grpc_ssl_channel_cred.assert_called_once_with(
-                certificate_chain=b"cert bytes", private_key=b"key bytes"
-            )
-            grpc_create_channel.assert_called_once_with(
-                "mtls.squid.clam.whelk:443",
-                credentials=cred,
-                credentials_file=None,
-                scopes=None,
-                ssl_credentials=mock_ssl_cred,
-                quota_project_id=None,
-                options=[
-                    ("grpc.max_send_message_length", -1),
-                    ("grpc.max_receive_message_length", -1),
-                ],
-            )
-            assert transport.grpc_channel == mock_grpc_channel
-            assert transport._ssl_channel_credentials == mock_ssl_cred
-
-
-# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
-# removed from grpc/grpc_asyncio transport constructor.
-@pytest.mark.parametrize("transport_class", [transports.DiscussServiceGrpcTransport, transports.DiscussServiceGrpcAsyncIOTransport]) -def test_discuss_service_transport_channel_mtls_with_adc( - transport_class -): - mock_ssl_cred = mock.Mock() - with mock.patch.multiple( - "google.auth.transport.grpc.SslCredentials", - __init__=mock.Mock(return_value=None), - ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), - ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - mock_cred = mock.Mock() - - with pytest.warns(DeprecationWarning): - transport = transport_class( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=None, - ) - - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=None, - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - assert transport.grpc_channel == mock_grpc_channel - - -def test_model_path(): - model = "squid" - expected = "models/{model}".format(model=model, ) - actual = DiscussServiceClient.model_path(model) - assert expected == actual - - -def test_parse_model_path(): - expected = { - "model": "clam", - } - path = DiscussServiceClient.model_path(**expected) - - # Check that the path construction is reversible. - actual = DiscussServiceClient.parse_model_path(path) - assert expected == actual - -def test_common_billing_account_path(): - billing_account = "whelk" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) - actual = DiscussServiceClient.common_billing_account_path(billing_account) - assert expected == actual - - -def test_parse_common_billing_account_path(): - expected = { - "billing_account": "octopus", - } - path = DiscussServiceClient.common_billing_account_path(**expected) - - # Check that the path construction is reversible. - actual = DiscussServiceClient.parse_common_billing_account_path(path) - assert expected == actual - -def test_common_folder_path(): - folder = "oyster" - expected = "folders/{folder}".format(folder=folder, ) - actual = DiscussServiceClient.common_folder_path(folder) - assert expected == actual - - -def test_parse_common_folder_path(): - expected = { - "folder": "nudibranch", - } - path = DiscussServiceClient.common_folder_path(**expected) - - # Check that the path construction is reversible. - actual = DiscussServiceClient.parse_common_folder_path(path) - assert expected == actual - -def test_common_organization_path(): - organization = "cuttlefish" - expected = "organizations/{organization}".format(organization=organization, ) - actual = DiscussServiceClient.common_organization_path(organization) - assert expected == actual - - -def test_parse_common_organization_path(): - expected = { - "organization": "mussel", - } - path = DiscussServiceClient.common_organization_path(**expected) - - # Check that the path construction is reversible. 
- actual = DiscussServiceClient.parse_common_organization_path(path) - assert expected == actual - -def test_common_project_path(): - project = "winkle" - expected = "projects/{project}".format(project=project, ) - actual = DiscussServiceClient.common_project_path(project) - assert expected == actual - - -def test_parse_common_project_path(): - expected = { - "project": "nautilus", - } - path = DiscussServiceClient.common_project_path(**expected) - - # Check that the path construction is reversible. - actual = DiscussServiceClient.parse_common_project_path(path) - assert expected == actual - -def test_common_location_path(): - project = "scallop" - location = "abalone" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) - actual = DiscussServiceClient.common_location_path(project, location) - assert expected == actual - - -def test_parse_common_location_path(): - expected = { - "project": "squid", - "location": "clam", - } - path = DiscussServiceClient.common_location_path(**expected) - - # Check that the path construction is reversible. - actual = DiscussServiceClient.parse_common_location_path(path) - assert expected == actual - - -def test_client_with_default_client_info(): - client_info = gapic_v1.client_info.ClientInfo() - - with mock.patch.object(transports.DiscussServiceTransport, '_prep_wrapped_messages') as prep: - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - client_info=client_info, - ) - prep.assert_called_once_with(client_info) - - with mock.patch.object(transports.DiscussServiceTransport, '_prep_wrapped_messages') as prep: - transport_class = DiscussServiceClient.get_transport_class() - transport = transport_class( - credentials=ga_credentials.AnonymousCredentials(), - client_info=client_info, - ) - prep.assert_called_once_with(client_info) - -@pytest.mark.asyncio -async def test_transport_close_async(): - client = DiscussServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) - with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close: - async with client: - close.assert_not_called() - close.assert_called_once() - - -def test_transport_close(): - transports = { - "rest": "_session", - "grpc": "_grpc_channel", - } - - for transport, close_name in transports.items(): - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport - ) - with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close: - with client: - close.assert_not_called() - close.assert_called_once() - -def test_client_ctx(): - transports = [ - 'rest', - 'grpc', - ] - for transport in transports: - client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport - ) - # Test client calls underlying transport. 
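-        # Exiting the client context manager must delegate to transport.close().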
- with mock.patch.object(type(client.transport), "close") as close: - close.assert_not_called() - with client: - pass - close.assert_called() - -@pytest.mark.parametrize("client_class,transport_class", [ - (DiscussServiceClient, transports.DiscussServiceGrpcTransport), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport), -]) -def test_api_key_credentials(client_class, transport_class): - with mock.patch.object( - google.auth._default, "get_api_key_credentials", create=True - ) as get_api_key_credentials: - mock_cred = mock.Mock() - get_api_key_credentials.return_value = mock_cred - options = client_options.ClientOptions() - options.api_key = "api_key" - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=mock_cred, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_model_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_model_service.py deleted file mode 100644 index c7a1ee1f30f8..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_model_service.py +++ /dev/null @@ -1,2319 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-#
-import os
-# try/except added for compatibility with python < 3.8
-try:
-    from unittest import mock
-    from unittest.mock import AsyncMock  # pragma: NO COVER
-except ImportError:  # pragma: NO COVER
-    import mock
-
-import grpc
-from grpc.experimental import aio
-from collections.abc import Iterable
-from google.protobuf import json_format
-import json
-import math
-import pytest
-from proto.marshal.rules.dates import DurationRule, TimestampRule
-from proto.marshal.rules import wrappers
-from requests import Response
-from requests import Request, PreparedRequest
-from requests.sessions import Session
-
-from google.ai.generativelanguage_v1beta2.services.model_service import ModelServiceAsyncClient
-from google.ai.generativelanguage_v1beta2.services.model_service import ModelServiceClient
-from google.ai.generativelanguage_v1beta2.services.model_service import pagers
-from google.ai.generativelanguage_v1beta2.services.model_service import transports
-from google.ai.generativelanguage_v1beta2.types import model
-from google.ai.generativelanguage_v1beta2.types import model_service
-from google.api_core import client_options
-from google.api_core import exceptions as core_exceptions
-from google.api_core import gapic_v1
-from google.api_core import grpc_helpers
-from google.api_core import grpc_helpers_async
-from google.api_core import path_template
-from google.auth import credentials as ga_credentials
-from google.auth.exceptions import MutualTLSChannelError
-from google.oauth2 import service_account
-import google.auth
-
-
-def client_cert_source_callback():
-    return b"cert bytes", b"key bytes"
-
-
-# If default endpoint is localhost, then default mtls endpoint will be the same.
-# This method modifies the default endpoint so the client can produce a different
-# mtls endpoint for endpoint testing purposes.
-def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT - - -def test__get_default_mtls_endpoint(): - api_endpoint = "example.googleapis.com" - api_mtls_endpoint = "example.mtls.googleapis.com" - sandbox_endpoint = "example.sandbox.googleapis.com" - sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" - non_googleapi = "api.example.com" - - assert ModelServiceClient._get_default_mtls_endpoint(None) is None - assert ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - - -@pytest.mark.parametrize("client_class,transport_name", [ - (ModelServiceClient, "grpc"), - (ModelServiceAsyncClient, "grpc_asyncio"), - (ModelServiceClient, "rest"), -]) -def test_model_service_client_from_service_account_info(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: - factory.return_value = creds - info = {"valid": True} - client = client_class.from_service_account_info(info, transport=transport_name) - assert client.transport._credentials == creds - assert isinstance(client, client_class) - - assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' - ) - - -@pytest.mark.parametrize("transport_class,transport_name", [ - (transports.ModelServiceGrpcTransport, "grpc"), - (transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (transports.ModelServiceRestTransport, "rest"), -]) -def test_model_service_client_service_account_always_use_jwt(transport_class, transport_name): - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: - creds = service_account.Credentials(None, None, None) - transport = transport_class(credentials=creds, always_use_jwt_access=True) - use_jwt.assert_called_once_with(True) - - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: - creds = service_account.Credentials(None, None, None) - transport = transport_class(credentials=creds, always_use_jwt_access=False) - use_jwt.assert_not_called() - - -@pytest.mark.parametrize("client_class,transport_name", [ - (ModelServiceClient, "grpc"), - (ModelServiceAsyncClient, "grpc_asyncio"), - (ModelServiceClient, "rest"), -]) -def test_model_service_client_from_service_account_file(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: - factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name) - assert client.transport._credentials == creds - assert isinstance(client, client_class) - - client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name) - assert client.transport._credentials == creds - assert isinstance(client, client_class) - - assert client.transport._host == ( - 
'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' - ) - - -def test_model_service_client_get_transport_class(): - transport = ModelServiceClient.get_transport_class() - available_transports = [ - transports.ModelServiceGrpcTransport, - transports.ModelServiceRestTransport, - ] - assert transport in available_transports - - transport = ModelServiceClient.get_transport_class("grpc") - assert transport == transports.ModelServiceGrpcTransport - - -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (ModelServiceClient, transports.ModelServiceRestTransport, "rest"), -]) -@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) -@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) -def test_model_service_client_client_options(client_class, transport_class, transport_name): - # Check that if channel is provided we won't create a new one. - with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=ga_credentials.AnonymousCredentials() - ) - client = client_class(transport=transport) - gtc.assert_not_called() - - # Check that if channel is provided via str we will create a new one. - with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: - client = client_class(transport=transport_name) - gtc.assert_called() - - # Check the case api_endpoint is provided. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(transport=transport_name, client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "always". 
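-    # "always" forces the mTLS endpoint even when no client certificate
-    # is configured.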
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has - # unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class(transport=transport_name) - - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): - with pytest.raises(ValueError): - client = client_class(transport=transport_name) - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id="octopus", - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - # Check the case api_endpoint is provided - options = client_options.ClientOptions(api_audience="https://language.googleapis.com") - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience="https://language.googleapis.com" - ) - -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), - (ModelServiceClient, transports.ModelServiceRestTransport, "rest", "true"), - (ModelServiceClient, transports.ModelServiceRestTransport, "rest", "false"), -]) -@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) -@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) -@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_model_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): - # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default - # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. - - # Check the case client_cert_source is provided. 
Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - - if use_client_cert_env == "false": - expected_client_cert_source = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_client_cert_source = client_cert_source_callback - expected_host = client.DEFAULT_MTLS_ENDPOINT - - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - client_cert_source_for_mtls=expected_client_cert_source, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case ADC client cert is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): - if use_client_cert_env == "false": - expected_host = client.DEFAULT_ENDPOINT - expected_client_cert_source = None - else: - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_client_cert_source = client_cert_source_callback - - patched.return_value = None - client = client_class(transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - client_cert_source_for_mtls=expected_client_cert_source, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): - patched.return_value = None - client = client_class(transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - -@pytest.mark.parametrize("client_class", [ - ModelServiceClient, ModelServiceAsyncClient -]) -@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) -@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) -def test_model_service_client_get_mtls_endpoint_and_cert_source(client_class): - mock_client_cert_source = mock.Mock() - - # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) - assert api_endpoint == mock_api_endpoint - assert cert_source == mock_client_cert_source - - # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): - mock_client_cert_source = mock.Mock() - mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) - assert api_endpoint == mock_api_endpoint - assert cert_source is None - - # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() - assert api_endpoint == client_class.DEFAULT_ENDPOINT - assert cert_source is None - - # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() - assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT - assert cert_source is None - - # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() - assert api_endpoint == client_class.DEFAULT_ENDPOINT - assert cert_source is None - - # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() - assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT - assert cert_source == mock_client_cert_source - - -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (ModelServiceClient, transports.ModelServiceRestTransport, "rest"), -]) -def test_model_service_client_client_options_scopes(client_class, transport_class, transport_name): - # Check the case scopes are provided. 
- options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=["1", "2"], - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", grpc_helpers), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), - (ModelServiceClient, transports.ModelServiceRestTransport, "rest", None), -]) -def test_model_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): - # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - -def test_model_service_client_client_options_from_dict(): - with mock.patch('google.ai.generativelanguage_v1beta2.services.model_service.transports.ModelServiceGrpcTransport.__init__') as grpc_transport: - grpc_transport.return_value = None - client = ModelServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} - ) - grpc_transport.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", grpc_helpers), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), -]) -def test_model_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): - # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # test that the credentials from file are saved and used as the credentials. 
- with mock.patch.object( - google.auth, "load_credentials_from_file", autospec=True - ) as load_creds, mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel" - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - file_creds = ga_credentials.AnonymousCredentials() - load_creds.return_value = (file_creds, None) - adc.return_value = (creds, None) - client = client_class(client_options=options, transport=transport_name) - create_channel.assert_called_with( - "generativelanguage.googleapis.com:443", - credentials=file_creds, - credentials_file=None, - quota_project_id=None, - default_scopes=( -), - scopes=None, - default_host="generativelanguage.googleapis.com", - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize("request_type", [ - model_service.GetModelRequest, - dict, -]) -def test_get_model(request_type, transport: str = 'grpc'): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = model.Model( - name='name_value', - base_model_id='base_model_id_value', - version='version_value', - display_name='display_name_value', - description='description_value', - input_token_limit=1838, - output_token_limit=1967, - supported_generation_methods=['supported_generation_methods_value'], - temperature=0.1198, - top_p=0.546, - top_k=541, - ) - response = client.get_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == model_service.GetModelRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, model.Model) - assert response.name == 'name_value' - assert response.base_model_id == 'base_model_id_value' - assert response.version == 'version_value' - assert response.display_name == 'display_name_value' - assert response.description == 'description_value' - assert response.input_token_limit == 1838 - assert response.output_token_limit == 1967 - assert response.supported_generation_methods == ['supported_generation_methods_value'] - assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) - assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) - assert response.top_k == 541 - - -def test_get_model_empty_call(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', - ) - - # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: - client.get_model() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == model_service.GetModelRequest() - -@pytest.mark.asyncio -async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelRequest): - client = ModelServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(model.Model( - name='name_value', - base_model_id='base_model_id_value', - version='version_value', - display_name='display_name_value', - description='description_value', - input_token_limit=1838, - output_token_limit=1967, - supported_generation_methods=['supported_generation_methods_value'], - temperature=0.1198, - top_p=0.546, - top_k=541, - )) - response = await client.get_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == model_service.GetModelRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, model.Model) - assert response.name == 'name_value' - assert response.base_model_id == 'base_model_id_value' - assert response.version == 'version_value' - assert response.display_name == 'display_name_value' - assert response.description == 'description_value' - assert response.input_token_limit == 1838 - assert response.output_token_limit == 1967 - assert response.supported_generation_methods == ['supported_generation_methods_value'] - assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) - assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) - assert response.top_k == 541 - - -@pytest.mark.asyncio -async def test_get_model_async_from_dict(): - await test_get_model_async(request_type=dict) - - -def test_get_model_field_headers(): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = model_service.GetModelRequest() - - request.name = 'name_value' - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: - call.return_value = model.Model() - client.get_model(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'name=name_value', - ) in kw['metadata'] - - -@pytest.mark.asyncio -async def test_get_model_field_headers_async(): - client = ModelServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. 
-    request = model_service.GetModelRequest()
-
-    request.name = 'name_value'
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.get_model),
-            '__call__') as call:
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model())
-        await client.get_model(request)
-
-        # Establish that the underlying gRPC stub method was called.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-        assert args[0] == request
-
-    # Establish that the field header was sent.
-    _, _, kw = call.mock_calls[0]
-    assert (
-        'x-goog-request-params',
-        'name=name_value',
-    ) in kw['metadata']
-
-
-def test_get_model_flattened():
-    client = ModelServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.get_model),
-            '__call__') as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = model.Model()
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        client.get_model(
-            name='name_value',
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        arg = args[0].name
-        mock_val = 'name_value'
-        assert arg == mock_val
-
-
-def test_get_model_flattened_error():
-    client = ModelServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-    )
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.get_model(
-            model_service.GetModelRequest(),
-            name='name_value',
-        )
-
-@pytest.mark.asyncio
-async def test_get_model_flattened_async():
-    client = ModelServiceAsyncClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.get_model),
-            '__call__') as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model())
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.get_model(
-            name='name_value',
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-        arg = args[0].name
-        mock_val = 'name_value'
-        assert arg == mock_val
-
-@pytest.mark.asyncio
-async def test_get_model_flattened_error_async():
-    client = ModelServiceAsyncClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-    )
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.get_model(
-            model_service.GetModelRequest(),
-            name='name_value',
-        )
-
-
-@pytest.mark.parametrize("request_type", [
-    model_service.ListModelsRequest,
-    dict,
-])
-def test_list_models(request_type, transport: str = 'grpc'):
-    client = ModelServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-        transport=transport,
-    )
-
-    # Everything is optional in proto3 as far as the runtime is concerned,
-    # and we are mocking out the actual API, so just send an empty request.
- request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelsResponse( - next_page_token='next_page_token_value', - ) - response = client.list_models(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == model_service.ListModelsRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListModelsPager) - assert response.next_page_token == 'next_page_token_value' - - -def test_list_models_empty_call(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: - client.list_models() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == model_service.ListModelsRequest() - -@pytest.mark.asyncio -async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelsRequest): - client = ModelServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse( - next_page_token='next_page_token_value', - )) - response = await client.list_models(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == model_service.ListModelsRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, pagers.ListModelsAsyncPager) - assert response.next_page_token == 'next_page_token_value' - - -@pytest.mark.asyncio -async def test_list_models_async_from_dict(): - await test_list_models_async(request_type=dict) - - -def test_list_models_flattened(): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = model_service.ListModelsResponse() - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.list_models( - page_size=951, - page_token='page_token_value', - ) - - # Establish that the underlying call was made with the expected - # request object values. 
-        assert len(call.mock_calls) == 1
-        _, args, _ = call.mock_calls[0]
-        arg = args[0].page_size
-        mock_val = 951
-        assert arg == mock_val
-        arg = args[0].page_token
-        mock_val = 'page_token_value'
-        assert arg == mock_val
-
-
-def test_list_models_flattened_error():
-    client = ModelServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-    )
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        client.list_models(
-            model_service.ListModelsRequest(),
-            page_size=951,
-            page_token='page_token_value',
-        )
-
-@pytest.mark.asyncio
-async def test_list_models_flattened_async():
-    client = ModelServiceAsyncClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__') as call:
-        # Designate an appropriate return value for the call.
-        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse())
-        # Call the method with a truthy value for each flattened field,
-        # using the keyword arguments to the method.
-        response = await client.list_models(
-            page_size=951,
-            page_token='page_token_value',
-        )
-
-        # Establish that the underlying call was made with the expected
-        # request object values.
-        assert len(call.mock_calls)
-        _, args, _ = call.mock_calls[0]
-        arg = args[0].page_size
-        mock_val = 951
-        assert arg == mock_val
-        arg = args[0].page_token
-        mock_val = 'page_token_value'
-        assert arg == mock_val
-
-@pytest.mark.asyncio
-async def test_list_models_flattened_error_async():
-    client = ModelServiceAsyncClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-    )
-
-    # Attempting to call a method with both a request object and flattened
-    # fields is an error.
-    with pytest.raises(ValueError):
-        await client.list_models(
-            model_service.ListModelsRequest(),
-            page_size=951,
-            page_token='page_token_value',
-        )
-
-
-def test_list_models_pager(transport_name: str = "grpc"):
-    client = ModelServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-        transport=transport_name,
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__') as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                    model.Model(),
-                ],
-                next_page_token='abc',
-            ),
-            model_service.ListModelsResponse(
-                models=[],
-                next_page_token='def',
-            ),
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                ],
-                next_page_token='ghi',
-            ),
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                ],
-            ),
-            RuntimeError,
-        )
-
-        metadata = ()
-        pager = client.list_models(request={})
-
-        assert pager._metadata == metadata
-
-        results = list(pager)
-        assert len(results) == 6
-        assert all(isinstance(i, model.Model)
-                   for i in results)
-def test_list_models_pages(transport_name: str = "grpc"):
-    client = ModelServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-        transport=transport_name,
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__') as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                    model.Model(),
-                ],
-                next_page_token='abc',
-            ),
-            model_service.ListModelsResponse(
-                models=[],
-                next_page_token='def',
-            ),
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                ],
-                next_page_token='ghi',
-            ),
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                ],
-            ),
-            RuntimeError,
-        )
-        pages = list(client.list_models(request={}).pages)
-        for page_, token in zip(pages, ['abc','def','ghi', '']):
-            assert page_.raw_page.next_page_token == token
-
-@pytest.mark.asyncio
-async def test_list_models_async_pager():
-    client = ModelServiceAsyncClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__', new_callable=mock.AsyncMock) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                    model.Model(),
-                ],
-                next_page_token='abc',
-            ),
-            model_service.ListModelsResponse(
-                models=[],
-                next_page_token='def',
-            ),
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                ],
-                next_page_token='ghi',
-            ),
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                ],
-            ),
-            RuntimeError,
-        )
-        async_pager = await client.list_models(request={},)
-        assert async_pager.next_page_token == 'abc'
-        responses = []
-        async for response in async_pager: # pragma: no branch
-            responses.append(response)
-
-        assert len(responses) == 6
-        assert all(isinstance(i, model.Model)
-                   for i in responses)
-
-
-@pytest.mark.asyncio
-async def test_list_models_async_pages():
-    client = ModelServiceAsyncClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-    )
-
-    # Mock the actual call within the gRPC stub, and fake the request.
-    with mock.patch.object(
-            type(client.transport.list_models),
-            '__call__', new_callable=mock.AsyncMock) as call:
-        # Set the response to a series of pages.
-        call.side_effect = (
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                    model.Model(),
-                ],
-                next_page_token='abc',
-            ),
-            model_service.ListModelsResponse(
-                models=[],
-                next_page_token='def',
-            ),
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                ],
-                next_page_token='ghi',
-            ),
-            model_service.ListModelsResponse(
-                models=[
-                    model.Model(),
-                    model.Model(),
-                ],
-            ),
-            RuntimeError,
-        )
-        pages = []
-        # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch`
-        # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372
-        async for page_ in ( # pragma: no branch
-            await client.list_models(request={})
-        ).pages:
-            pages.append(page_)
-        for page_, token in zip(pages, ['abc','def','ghi', '']):
-            assert page_.raw_page.next_page_token == token
-
-
-@pytest.mark.parametrize("request_type", [
-    model_service.GetModelRequest,
-    dict,
-])
-def test_get_model_rest(request_type):
-    client = ModelServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-        transport="rest",
-    )
-
-    # send a request that will satisfy transcoding
-    request_init = {'name': 'models/sample1'}
-    request = request_type(**request_init)
-
-    # Mock the http request call within the method and fake a response.
- with mock.patch.object(type(client.transport._session), 'request') as req: - # Designate an appropriate value for the returned response. - return_value = model.Model( - name='name_value', - base_model_id='base_model_id_value', - version='version_value', - display_name='display_name_value', - description='description_value', - input_token_limit=1838, - output_token_limit=1967, - supported_generation_methods=['supported_generation_methods_value'], - temperature=0.1198, - top_p=0.546, - top_k=541, - ) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - pb_return_value = model.Model.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) - - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - response = client.get_model(request) - - # Establish that the response is the type that we expect. - assert isinstance(response, model.Model) - assert response.name == 'name_value' - assert response.base_model_id == 'base_model_id_value' - assert response.version == 'version_value' - assert response.display_name == 'display_name_value' - assert response.description == 'description_value' - assert response.input_token_limit == 1838 - assert response.output_token_limit == 1967 - assert response.supported_generation_methods == ['supported_generation_methods_value'] - assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) - assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) - assert response.top_k == 541 - - -def test_get_model_rest_required_fields(request_type=model_service.GetModelRequest): - transport_class = transports.ModelServiceRestTransport - - request_init = {} - request_init["name"] = "" - request = request_type(**request_init) - pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) - - # verify fields with default values are dropped - - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_model._get_unset_required_fields(jsonified_request) - jsonified_request.update(unset_fields) - - # verify required fields with default values are now present - - jsonified_request["name"] = 'name_value' - - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_model._get_unset_required_fields(jsonified_request) - jsonified_request.update(unset_fields) - - # verify required fields with non-default values are left alone - assert "name" in jsonified_request - assert jsonified_request["name"] == 'name_value' - - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest', - ) - request = request_type(**request_init) - - # Designate an appropriate value for the returned response. - return_value = model.Model() - # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: - # We need to mock transcode() because providing default values - # for required fields will fail the real version if the http_options - # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: - # A uri without fields and an empty body will force all the - # request fields to show up in the query_params. 
- pb_request = request_type.pb(request) - transcode_result = { - 'uri': 'v1/sample_method', - 'method': "get", - 'query_params': pb_request, - } - transcode.return_value = transcode_result - - response_value = Response() - response_value.status_code = 200 - - pb_return_value = model.Model.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) - - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - - response = client.get_model(request) - - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] - assert expected_params == actual_params - - -def test_get_model_rest_unset_required_fields(): - transport = transports.ModelServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) - - unset_fields = transport.get_model._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("name", ))) - - -@pytest.mark.parametrize("null_interceptor", [True, False]) -def test_get_model_rest_interceptors(null_interceptor): - transport = transports.ModelServiceRestTransport( - credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(), - ) - client = ModelServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "post_get_model") as post, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "pre_get_model") as pre: - pre.assert_not_called() - post.assert_not_called() - pb_message = model_service.GetModelRequest.pb(model_service.GetModelRequest()) - transcode.return_value = { - "method": "post", - "uri": "my_uri", - "body": pb_message, - "query_params": pb_message, - } - - req.return_value = Response() - req.return_value.status_code = 200 - req.return_value.request = PreparedRequest() - req.return_value._content = model.Model.to_json(model.Model()) - - request = model_service.GetModelRequest() - metadata =[ - ("key", "val"), - ("cephalopod", "squid"), - ] - pre.return_value = request, metadata - post.return_value = model.Model() - - client.get_model(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) - - pre.assert_called_once() - post.assert_called_once() - - -def test_get_model_rest_bad_request(transport: str = 'rest', request_type=model_service.GetModelRequest): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # send a request that will satisfy transcoding - request_init = {'name': 'models/sample1'} - request = request_type(**request_init) - - # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 400 - response_value.request = Request() - req.return_value = response_value - client.get_model(request) - - -def test_get_model_rest_flattened(): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: - # Designate an appropriate value for the returned response. 
- return_value = model.Model() - - # get arguments that satisfy an http rule for this method - sample_request = {'name': 'models/sample1'} - - # get truthy value for each flattened field - mock_args = dict( - name='name_value', - ) - mock_args.update(sample_request) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - pb_return_value = model.Model.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - - client.get_model(**mock_args) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(req.mock_calls) == 1 - _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta2/{name=models/*}" % client.transport._host, args[1]) - - -def test_get_model_rest_flattened_error(transport: str = 'rest'): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.get_model( - model_service.GetModelRequest(), - name='name_value', - ) - - -def test_get_model_rest_error(): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' - ) - - -@pytest.mark.parametrize("request_type", [ - model_service.ListModelsRequest, - dict, -]) -def test_list_models_rest(request_type): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - - # send a request that will satisfy transcoding - request_init = {} - request = request_type(**request_init) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: - # Designate an appropriate value for the returned response. - return_value = model_service.ListModelsResponse( - next_page_token='next_page_token_value', - ) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - pb_return_value = model_service.ListModelsResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) - - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - response = client.list_models(request) - - # Establish that the response is the type that we expect. 
- assert isinstance(response, pagers.ListModelsPager) - assert response.next_page_token == 'next_page_token_value' - - -@pytest.mark.parametrize("null_interceptor", [True, False]) -def test_list_models_rest_interceptors(null_interceptor): - transport = transports.ModelServiceRestTransport( - credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(), - ) - client = ModelServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "post_list_models") as post, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "pre_list_models") as pre: - pre.assert_not_called() - post.assert_not_called() - pb_message = model_service.ListModelsRequest.pb(model_service.ListModelsRequest()) - transcode.return_value = { - "method": "post", - "uri": "my_uri", - "body": pb_message, - "query_params": pb_message, - } - - req.return_value = Response() - req.return_value.status_code = 200 - req.return_value.request = PreparedRequest() - req.return_value._content = model_service.ListModelsResponse.to_json(model_service.ListModelsResponse()) - - request = model_service.ListModelsRequest() - metadata =[ - ("key", "val"), - ("cephalopod", "squid"), - ] - pre.return_value = request, metadata - post.return_value = model_service.ListModelsResponse() - - client.list_models(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) - - pre.assert_called_once() - post.assert_called_once() - - -def test_list_models_rest_bad_request(transport: str = 'rest', request_type=model_service.ListModelsRequest): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # send a request that will satisfy transcoding - request_init = {} - request = request_type(**request_init) - - # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 400 - response_value.request = Request() - req.return_value = response_value - client.list_models(request) - - -def test_list_models_rest_flattened(): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: - # Designate an appropriate value for the returned response. - return_value = model_service.ListModelsResponse() - - # get arguments that satisfy an http rule for this method - sample_request = {} - - # get truthy value for each flattened field - mock_args = dict( - page_size=951, - page_token='page_token_value', - ) - mock_args.update(sample_request) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - pb_return_value = model_service.ListModelsResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - - client.list_models(**mock_args) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(req.mock_calls) == 1 - _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta2/models" % client.transport._host, args[1]) - - -def test_list_models_rest_flattened_error(transport: str = 'rest'): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.list_models( - model_service.ListModelsRequest(), - page_size=951, - page_token='page_token_value', - ) - - -def test_list_models_rest_pager(transport: str = 'rest'): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: - # TODO(kbandes): remove this mock unless there's a good reason for it. - #with mock.patch.object(path_template, 'transcode') as transcode: - # Set the response as a series of pages - response = ( - model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - model.Model(), - ], - next_page_token='abc', - ), - model_service.ListModelsResponse( - models=[], - next_page_token='def', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - ], - next_page_token='ghi', - ), - model_service.ListModelsResponse( - models=[ - model.Model(), - model.Model(), - ], - ), - ) - # Two responses for two calls - response = response + response - - # Wrap the values into proper Response objs - response = tuple(model_service.ListModelsResponse.to_json(x) for x in response) - return_values = tuple(Response() for i in response) - for return_val, response_val in zip(return_values, response): - return_val._content = response_val.encode('UTF-8') - return_val.status_code = 200 - req.side_effect = return_values - - sample_request = {} - - pager = client.list_models(request=sample_request) - - results = list(pager) - assert len(results) == 6 - assert all(isinstance(i, model.Model) - for i in results) - - pages = list(client.list_models(request=sample_request).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): - assert page_.raw_page.next_page_token == token - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.ModelServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # It is an error to provide a credentials file and a transport instance. - transport = transports.ModelServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = ModelServiceClient( - client_options={"credentials_file": "credentials.json"}, - transport=transport, - ) - - # It is an error to provide an api_key and a transport instance. - transport = transports.ModelServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - options = client_options.ClientOptions() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = ModelServiceClient( - client_options=options, - transport=transport, - ) - - # It is an error to provide an api_key and a credential. 
- options = mock.Mock() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = ModelServiceClient( - client_options=options, - credentials=ga_credentials.AnonymousCredentials() - ) - - # It is an error to provide scopes and a transport instance. - transport = transports.ModelServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = ModelServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. - transport = transports.ModelServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - client = ModelServiceClient(transport=transport) - assert client.transport is transport - -def test_transport_get_channel(): - # A client may be instantiated with a custom transport instance. - transport = transports.ModelServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - - transport = transports.ModelServiceGrpcAsyncIOTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - -@pytest.mark.parametrize("transport_class", [ - transports.ModelServiceGrpcTransport, - transports.ModelServiceGrpcAsyncIOTransport, - transports.ModelServiceRestTransport, -]) -def test_transport_adc(transport_class): - # Test default credentials are used if not provided. - with mock.patch.object(google.auth, 'default') as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - transport_class() - adc.assert_called_once() - -@pytest.mark.parametrize("transport_name", [ - "grpc", - "rest", -]) -def test_transport_kind(transport_name): - transport = ModelServiceClient.get_transport_class(transport_name)( - credentials=ga_credentials.AnonymousCredentials(), - ) - assert transport.kind == transport_name - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.ModelServiceGrpcTransport, - ) - -def test_model_service_base_transport_error(): - # Passing both a credentials object and credentials_file should raise an error - with pytest.raises(core_exceptions.DuplicateCredentialArgs): - transport = transports.ModelServiceTransport( - credentials=ga_credentials.AnonymousCredentials(), - credentials_file="credentials.json" - ) - - -def test_model_service_base_transport(): - # Instantiate the base transport. - with mock.patch('google.ai.generativelanguage_v1beta2.services.model_service.transports.ModelServiceTransport.__init__') as Transport: - Transport.return_value = None - transport = transports.ModelServiceTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. 
- methods = ( - 'get_model', - 'list_models', - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - with pytest.raises(NotImplementedError): - transport.close() - - # Catch all for all remaining methods and properties - remainder = [ - 'kind', - ] - for r in remainder: - with pytest.raises(NotImplementedError): - getattr(transport, r)() - - -def test_model_service_base_transport_with_credentials_file(): - # Instantiate the base transport with a credentials file - with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta2.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: - Transport.return_value = None - load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) - transport = transports.ModelServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", - ) - load_creds.assert_called_once_with("credentials.json", - scopes=None, - default_scopes=( -), - quota_project_id="octopus", - ) - - -def test_model_service_base_transport_with_adc(): - # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta2.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: - Transport.return_value = None - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - transport = transports.ModelServiceTransport() - adc.assert_called_once() - - -def test_model_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(google.auth, 'default', autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - ModelServiceClient() - adc.assert_called_once_with( - scopes=None, - default_scopes=( -), - quota_project_id=None, - ) - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.ModelServiceGrpcTransport, - transports.ModelServiceGrpcAsyncIOTransport, - ], -) -def test_model_service_transport_auth_adc(transport_class): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object(google.auth, 'default', autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - adc.assert_called_once_with( - scopes=["1", "2"], - default_scopes=(), - quota_project_id="octopus", - ) - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.ModelServiceGrpcTransport, - transports.ModelServiceGrpcAsyncIOTransport, - transports.ModelServiceRestTransport, - ], -) -def test_model_service_transport_auth_gdch_credentials(transport_class): - host = 'https://language.com' - api_audience_tests = [None, 'https://language2.com'] - api_audience_expect = [host, 'https://language2.com'] - for t, e in zip(api_audience_tests, api_audience_expect): - with mock.patch.object(google.auth, 'default', autospec=True) as adc: - gdch_mock = mock.MagicMock() - type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock) - adc.return_value = (gdch_mock, None) - transport_class(host=host, api_audience=t) - gdch_mock.with_gdch_audience.assert_called_once_with( - e - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.ModelServiceGrpcTransport, grpc_helpers), - (transports.ModelServiceGrpcAsyncIOTransport, grpc_helpers_async) - ], -) -def test_model_service_transport_create_channel(transport_class, grpc_helpers): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class( - quota_project_id="octopus", - scopes=["1", "2"] - ) - - create_channel.assert_called_with( - "generativelanguage.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - default_scopes=( -), - scopes=["1", "2"], - default_host="generativelanguage.googleapis.com", - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) -def test_model_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): - cred = ga_credentials.AnonymousCredentials() - - # Check ssl_channel_credentials is used if provided. - with mock.patch.object(transport_class, "create_channel") as mock_create_channel: - mock_ssl_channel_creds = mock.Mock() - transport_class( - host="squid.clam.whelk", - credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds - ) - mock_create_channel.assert_called_once_with( - "squid.clam.whelk:443", - credentials=cred, - credentials_file=None, - scopes=None, - ssl_credentials=mock_ssl_channel_creds, - quota_project_id=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls - # is used. 
-    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
-        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
-            transport_class(
-                credentials=cred,
-                client_cert_source_for_mtls=client_cert_source_callback
-            )
-            expected_cert, expected_key = client_cert_source_callback()
-            mock_ssl_cred.assert_called_once_with(
-                certificate_chain=expected_cert,
-                private_key=expected_key
-            )
-
-def test_model_service_http_transport_client_cert_source_for_mtls():
-    cred = ga_credentials.AnonymousCredentials()
-    with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel:
-        transports.ModelServiceRestTransport(
-            credentials=cred,
-            client_cert_source_for_mtls=client_cert_source_callback
-        )
-        mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)
-
-
-@pytest.mark.parametrize("transport_name", [
-    "grpc",
-    "grpc_asyncio",
-    "rest",
-])
-def test_model_service_host_no_port(transport_name):
-    client = ModelServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'),
-        transport=transport_name,
-    )
-    assert client.transport._host == (
-        'generativelanguage.googleapis.com:443'
-        if transport_name in ['grpc', 'grpc_asyncio']
-        else 'https://generativelanguage.googleapis.com'
-    )
-
-@pytest.mark.parametrize("transport_name", [
-    "grpc",
-    "grpc_asyncio",
-    "rest",
-])
-def test_model_service_host_with_port(transport_name):
-    client = ModelServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'),
-        transport=transport_name,
-    )
-    assert client.transport._host == (
-        'generativelanguage.googleapis.com:8000'
-        if transport_name in ['grpc', 'grpc_asyncio']
-        else 'https://generativelanguage.googleapis.com:8000'
-    )
-
-@pytest.mark.parametrize("transport_name", [
-    "rest",
-])
-def test_model_service_client_transport_session_collision(transport_name):
-    creds1 = ga_credentials.AnonymousCredentials()
-    creds2 = ga_credentials.AnonymousCredentials()
-    client1 = ModelServiceClient(
-        credentials=creds1,
-        transport=transport_name,
-    )
-    client2 = ModelServiceClient(
-        credentials=creds2,
-        transport=transport_name,
-    )
-    session1 = client1.transport.get_model._session
-    session2 = client2.transport.get_model._session
-    assert session1 != session2
-    session1 = client1.transport.list_models._session
-    session2 = client2.transport.list_models._session
-    assert session1 != session2
-def test_model_service_grpc_transport_channel():
-    channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials())
-
-    # Check that channel is used if provided.
-    transport = transports.ModelServiceGrpcTransport(
-        host="squid.clam.whelk",
-        channel=channel,
-    )
-    assert transport.grpc_channel == channel
-    assert transport._host == "squid.clam.whelk:443"
-    assert transport._ssl_channel_credentials is None
-
-
-def test_model_service_grpc_asyncio_transport_channel():
-    channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials())
-
-    # Check that channel is used if provided.
-    transport = transports.ModelServiceGrpcAsyncIOTransport(
-        host="squid.clam.whelk",
-        channel=channel,
-    )
-    assert transport.grpc_channel == channel
-    assert transport._host == "squid.clam.whelk:443"
-    assert transport._ssl_channel_credentials is None
-
-
-# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
-# removed from grpc/grpc_asyncio transport constructor.
-@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport])
-def test_model_service_transport_channel_mtls_with_client_cert_source(
-    transport_class
-):
-    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
-        with mock.patch.object(transport_class, "create_channel") as grpc_create_channel:
-            mock_ssl_cred = mock.Mock()
-            grpc_ssl_channel_cred.return_value = mock_ssl_cred
-
-            mock_grpc_channel = mock.Mock()
-            grpc_create_channel.return_value = mock_grpc_channel
-
-            cred = ga_credentials.AnonymousCredentials()
-            with pytest.warns(DeprecationWarning):
-                with mock.patch.object(google.auth, 'default') as adc:
-                    adc.return_value = (cred, None)
-                    transport = transport_class(
-                        host="squid.clam.whelk",
-                        api_mtls_endpoint="mtls.squid.clam.whelk",
-                        client_cert_source=client_cert_source_callback,
-                    )
-                    adc.assert_called_once()
-
-            grpc_ssl_channel_cred.assert_called_once_with(
-                certificate_chain=b"cert bytes", private_key=b"key bytes"
-            )
-            grpc_create_channel.assert_called_once_with(
-                "mtls.squid.clam.whelk:443",
-                credentials=cred,
-                credentials_file=None,
-                scopes=None,
-                ssl_credentials=mock_ssl_cred,
-                quota_project_id=None,
-                options=[
-                    ("grpc.max_send_message_length", -1),
-                    ("grpc.max_receive_message_length", -1),
-                ],
-            )
-            assert transport.grpc_channel == mock_grpc_channel
-            assert transport._ssl_channel_credentials == mock_ssl_cred
-
-
-# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
-# removed from grpc/grpc_asyncio transport constructor.
-@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) -def test_model_service_transport_channel_mtls_with_adc( - transport_class -): - mock_ssl_cred = mock.Mock() - with mock.patch.multiple( - "google.auth.transport.grpc.SslCredentials", - __init__=mock.Mock(return_value=None), - ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), - ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: - mock_grpc_channel = mock.Mock() - grpc_create_channel.return_value = mock_grpc_channel - mock_cred = mock.Mock() - - with pytest.warns(DeprecationWarning): - transport = transport_class( - host="squid.clam.whelk", - credentials=mock_cred, - api_mtls_endpoint="mtls.squid.clam.whelk", - client_cert_source=None, - ) - - grpc_create_channel.assert_called_once_with( - "mtls.squid.clam.whelk:443", - credentials=mock_cred, - credentials_file=None, - scopes=None, - ssl_credentials=mock_ssl_cred, - quota_project_id=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - assert transport.grpc_channel == mock_grpc_channel - - -def test_model_path(): - model = "squid" - expected = "models/{model}".format(model=model, ) - actual = ModelServiceClient.model_path(model) - assert expected == actual - - -def test_parse_model_path(): - expected = { - "model": "clam", - } - path = ModelServiceClient.model_path(**expected) - - # Check that the path construction is reversible. - actual = ModelServiceClient.parse_model_path(path) - assert expected == actual - -def test_common_billing_account_path(): - billing_account = "whelk" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) - actual = ModelServiceClient.common_billing_account_path(billing_account) - assert expected == actual - - -def test_parse_common_billing_account_path(): - expected = { - "billing_account": "octopus", - } - path = ModelServiceClient.common_billing_account_path(**expected) - - # Check that the path construction is reversible. - actual = ModelServiceClient.parse_common_billing_account_path(path) - assert expected == actual - -def test_common_folder_path(): - folder = "oyster" - expected = "folders/{folder}".format(folder=folder, ) - actual = ModelServiceClient.common_folder_path(folder) - assert expected == actual - - -def test_parse_common_folder_path(): - expected = { - "folder": "nudibranch", - } - path = ModelServiceClient.common_folder_path(**expected) - - # Check that the path construction is reversible. - actual = ModelServiceClient.parse_common_folder_path(path) - assert expected == actual - -def test_common_organization_path(): - organization = "cuttlefish" - expected = "organizations/{organization}".format(organization=organization, ) - actual = ModelServiceClient.common_organization_path(organization) - assert expected == actual - - -def test_parse_common_organization_path(): - expected = { - "organization": "mussel", - } - path = ModelServiceClient.common_organization_path(**expected) - - # Check that the path construction is reversible. 
- actual = ModelServiceClient.parse_common_organization_path(path) - assert expected == actual - -def test_common_project_path(): - project = "winkle" - expected = "projects/{project}".format(project=project, ) - actual = ModelServiceClient.common_project_path(project) - assert expected == actual - - -def test_parse_common_project_path(): - expected = { - "project": "nautilus", - } - path = ModelServiceClient.common_project_path(**expected) - - # Check that the path construction is reversible. - actual = ModelServiceClient.parse_common_project_path(path) - assert expected == actual - -def test_common_location_path(): - project = "scallop" - location = "abalone" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) - actual = ModelServiceClient.common_location_path(project, location) - assert expected == actual - - -def test_parse_common_location_path(): - expected = { - "project": "squid", - "location": "clam", - } - path = ModelServiceClient.common_location_path(**expected) - - # Check that the path construction is reversible. - actual = ModelServiceClient.parse_common_location_path(path) - assert expected == actual - - -def test_client_with_default_client_info(): - client_info = gapic_v1.client_info.ClientInfo() - - with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - client_info=client_info, - ) - prep.assert_called_once_with(client_info) - - with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: - transport_class = ModelServiceClient.get_transport_class() - transport = transport_class( - credentials=ga_credentials.AnonymousCredentials(), - client_info=client_info, - ) - prep.assert_called_once_with(client_info) - -@pytest.mark.asyncio -async def test_transport_close_async(): - client = ModelServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) - with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close: - async with client: - close.assert_not_called() - close.assert_called_once() - - -def test_transport_close(): - transports = { - "rest": "_session", - "grpc": "_grpc_channel", - } - - for transport, close_name in transports.items(): - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport - ) - with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close: - with client: - close.assert_not_called() - close.assert_called_once() - -def test_client_ctx(): - transports = [ - 'rest', - 'grpc', - ] - for transport in transports: - client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport - ) - # Test client calls underlying transport. 
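The close/ctx tests in this stretch protect the client's context-manager contract: entering the block yields the client, and exiting closes the underlying channel or session exactly once. In application form, a sketch under the same anonymous-credentials assumption the tests use:

    # Leaving the block closes the transport's gRPC channel or HTTP session.
    with ModelServiceClient(credentials=ga_credentials.AnonymousCredentials()) as client:
        pass
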
- with mock.patch.object(type(client.transport), "close") as close: - close.assert_not_called() - with client: - pass - close.assert_called() - -@pytest.mark.parametrize("client_class,transport_class", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport), -]) -def test_api_key_credentials(client_class, transport_class): - with mock.patch.object( - google.auth._default, "get_api_key_credentials", create=True - ) as get_api_key_credentials: - mock_cred = mock.Mock() - get_api_key_credentials.return_value = mock_cred - options = client_options.ClientOptions() - options.api_key = "api_key" - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=mock_cred, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_text_service.py b/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_text_service.py deleted file mode 100644 index 2fbd8b3036c2..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/tests/unit/gapic/generativelanguage_v1beta2/test_text_service.py +++ /dev/null @@ -1,2214 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import os -# try/except added for compatibility with python < 3.8 -try: - from unittest import mock - from unittest.mock import AsyncMock # pragma: NO COVER -except ImportError: # pragma: NO COVER - import mock - -import grpc -from grpc.experimental import aio -from collections.abc import Iterable -from google.protobuf import json_format -import json -import math -import pytest -from proto.marshal.rules.dates import DurationRule, TimestampRule -from proto.marshal.rules import wrappers -from requests import Response -from requests import Request, PreparedRequest -from requests.sessions import Session -from google.protobuf import json_format - -from google.ai.generativelanguage_v1beta2.services.text_service import TextServiceAsyncClient -from google.ai.generativelanguage_v1beta2.services.text_service import TextServiceClient -from google.ai.generativelanguage_v1beta2.services.text_service import transports -from google.ai.generativelanguage_v1beta2.types import safety -from google.ai.generativelanguage_v1beta2.types import text_service -from google.api_core import client_options -from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import path_template -from google.auth import credentials as ga_credentials -from google.auth.exceptions import MutualTLSChannelError -from google.oauth2 import service_account -import google.auth - - -def client_cert_source_callback(): - return b"cert bytes", b"key bytes" - - -# If default endpoint is localhost, then default mtls endpoint will be the same. -# This method modifies the default endpoint so the client can produce a different -# mtls endpoint for endpoint testing purposes. 
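For orientation before the endpoint tests: test__get_default_mtls_endpoint, just below, pins down how the default mTLS endpoint is derived from an API endpoint, inserting "mtls" into *.googleapis.com hosts (including sandbox ones) and leaving everything else untouched. A rough reconstruction consistent with those assertions, not the library's actual code:

    import re

    def get_default_mtls_endpoint(api_endpoint):
        if not api_endpoint:
            return api_endpoint
        m = re.match(
            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<domain>\.googleapis\.com)?",
            api_endpoint,
        )
        if not m or not m.group("domain"):
            return api_endpoint  # non-googleapis hosts pass through unchanged
        return "{}.mtls{}.googleapis.com".format(m.group("name"), m.group("sandbox") or "")
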
-def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT - - -def test__get_default_mtls_endpoint(): - api_endpoint = "example.googleapis.com" - api_mtls_endpoint = "example.mtls.googleapis.com" - sandbox_endpoint = "example.sandbox.googleapis.com" - sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" - non_googleapi = "api.example.com" - - assert TextServiceClient._get_default_mtls_endpoint(None) is None - assert TextServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert TextServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert TextServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert TextServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert TextServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - - -@pytest.mark.parametrize("client_class,transport_name", [ - (TextServiceClient, "grpc"), - (TextServiceAsyncClient, "grpc_asyncio"), - (TextServiceClient, "rest"), -]) -def test_text_service_client_from_service_account_info(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: - factory.return_value = creds - info = {"valid": True} - client = client_class.from_service_account_info(info, transport=transport_name) - assert client.transport._credentials == creds - assert isinstance(client, client_class) - - assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' - ) - - -@pytest.mark.parametrize("transport_class,transport_name", [ - (transports.TextServiceGrpcTransport, "grpc"), - (transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (transports.TextServiceRestTransport, "rest"), -]) -def test_text_service_client_service_account_always_use_jwt(transport_class, transport_name): - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: - creds = service_account.Credentials(None, None, None) - transport = transport_class(credentials=creds, always_use_jwt_access=True) - use_jwt.assert_called_once_with(True) - - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: - creds = service_account.Credentials(None, None, None) - transport = transport_class(credentials=creds, always_use_jwt_access=False) - use_jwt.assert_not_called() - - -@pytest.mark.parametrize("client_class,transport_name", [ - (TextServiceClient, "grpc"), - (TextServiceAsyncClient, "grpc_asyncio"), - (TextServiceClient, "rest"), -]) -def test_text_service_client_from_service_account_file(client_class, transport_name): - creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: - factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name) - assert client.transport._credentials == creds - assert isinstance(client, client_class) - - client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name) - assert client.transport._credentials == creds - assert isinstance(client, client_class) - - assert client.transport._host == ( - 
'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' - ) - - -def test_text_service_client_get_transport_class(): - transport = TextServiceClient.get_transport_class() - available_transports = [ - transports.TextServiceGrpcTransport, - transports.TextServiceRestTransport, - ] - assert transport in available_transports - - transport = TextServiceClient.get_transport_class("grpc") - assert transport == transports.TextServiceGrpcTransport - - -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (TextServiceClient, transports.TextServiceGrpcTransport, "grpc"), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (TextServiceClient, transports.TextServiceRestTransport, "rest"), -]) -@mock.patch.object(TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient)) -@mock.patch.object(TextServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceAsyncClient)) -def test_text_service_client_client_options(client_class, transport_class, transport_name): - # Check that if channel is provided we won't create a new one. - with mock.patch.object(TextServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=ga_credentials.AnonymousCredentials() - ) - client = client_class(transport=transport) - gtc.assert_not_called() - - # Check that if channel is provided via str we will create a new one. - with mock.patch.object(TextServiceClient, 'get_transport_class') as gtc: - client = client_class(transport=transport_name) - gtc.assert_called() - - # Check the case api_endpoint is provided. - options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(transport=transport_name, client_options=options) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is - # "always". 
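The environment cases above and below reduce to a small decision rule for GOOGLE_API_USE_MTLS_ENDPOINT. Condensed as a sketch (hypothetical helper; the real logic lives inside the client constructor, and MutualTLSChannelError is the exception this test file already imports):

    def select_endpoint(use_mtls_env, default_endpoint, mtls_endpoint, cert_available):
        if use_mtls_env == "never":
            return default_endpoint
        if use_mtls_env == "always":
            return mtls_endpoint
        if use_mtls_env == "auto":
            return mtls_endpoint if cert_available else default_endpoint
        # Any other value is rejected, as the "Unsupported" case below asserts.
        raise MutualTLSChannelError("unsupported GOOGLE_API_USE_MTLS_ENDPOINT value")
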
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_MTLS_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has - # unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): - with pytest.raises(MutualTLSChannelError): - client = client_class(transport=transport_name) - - # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): - with pytest.raises(ValueError): - client = client_class(transport=transport_name) - - # Check the case quota_project_id is provided - options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id="octopus", - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - # Check the case api_endpoint is provided - options = client_options.ClientOptions(api_audience="https://language.googleapis.com") - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience="https://language.googleapis.com" - ) - -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", "true"), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", "false"), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), - (TextServiceClient, transports.TextServiceRestTransport, "rest", "true"), - (TextServiceClient, transports.TextServiceRestTransport, "rest", "false"), -]) -@mock.patch.object(TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient)) -@mock.patch.object(TextServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceAsyncClient)) -@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_text_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): - # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default - # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. - - # Check the case client_cert_source is provided. 
Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - - if use_client_cert_env == "false": - expected_client_cert_source = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_client_cert_source = client_cert_source_callback - expected_host = client.DEFAULT_MTLS_ENDPOINT - - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - client_cert_source_for_mtls=expected_client_cert_source, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case ADC client cert is provided. Whether client cert is used depends on - # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): - if use_client_cert_env == "false": - expected_host = client.DEFAULT_ENDPOINT - expected_client_cert_source = None - else: - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_client_cert_source = client_cert_source_callback - - patched.return_value = None - client = client_class(transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - client_cert_source_for_mtls=expected_client_cert_source, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): - patched.return_value = None - client = client_class(transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - -@pytest.mark.parametrize("client_class", [ - TextServiceClient, TextServiceAsyncClient -]) -@mock.patch.object(TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient)) -@mock.patch.object(TextServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceAsyncClient)) -def test_text_service_client_get_mtls_endpoint_and_cert_source(client_class): - mock_client_cert_source = mock.Mock() - - # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". 
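A note on the "client cert source" these mTLS cases keep passing around: it is nothing more than a zero-argument callable returning a (certificate_chain, private_key) bytes pair, exactly like the module-level client_cert_source_callback. A real-world source might look like this (file paths are hypothetical):

    def my_cert_source():
        # Any zero-arg callable returning (cert_bytes, key_bytes) satisfies
        # the interface; the paths here are placeholders.
        with open("client.crt", "rb") as crt, open("client.key", "rb") as key:
            return crt.read(), key.read()
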
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) - assert api_endpoint == mock_api_endpoint - assert cert_source == mock_client_cert_source - - # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): - mock_client_cert_source = mock.Mock() - mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) - assert api_endpoint == mock_api_endpoint - assert cert_source is None - - # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() - assert api_endpoint == client_class.DEFAULT_ENDPOINT - assert cert_source is None - - # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() - assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT - assert cert_source is None - - # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() - assert api_endpoint == client_class.DEFAULT_ENDPOINT - assert cert_source is None - - # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() - assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT - assert cert_source == mock_client_cert_source - - -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (TextServiceClient, transports.TextServiceGrpcTransport, "grpc"), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (TextServiceClient, transports.TextServiceRestTransport, "rest"), -]) -def test_text_service_client_client_options_scopes(client_class, transport_class, transport_name): - # Check the case scopes are provided. 
- options = client_options.ClientOptions( - scopes=["1", "2"], - ) - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=["1", "2"], - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", grpc_helpers), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), - (TextServiceClient, transports.TextServiceRestTransport, "rest", None), -]) -def test_text_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): - # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - -def test_text_service_client_client_options_from_dict(): - with mock.patch('google.ai.generativelanguage_v1beta2.services.text_service.transports.TextServiceGrpcTransport.__init__') as grpc_transport: - grpc_transport.return_value = None - client = TextServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} - ) - grpc_transport.assert_called_once_with( - credentials=None, - credentials_file=None, - host="squid.clam.whelk", - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", grpc_helpers), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), -]) -def test_text_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): - # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) - - with mock.patch.object(transport_class, '__init__') as patched: - patched.return_value = None - client = client_class(client_options=options, transport=transport_name) - patched.assert_called_once_with( - credentials=None, - credentials_file="credentials.json", - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) - - # test that the credentials from file are saved and used as the credentials. 
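The block below then verifies that the file-loaded credentials, not ADC, are what reach channel creation. The hook being patched is google.auth.load_credentials_from_file, which returns a (credentials, project_id) pair, so application code exercising the same option would do roughly:

    import google.auth

    # "credentials.json" is the same placeholder path the tests use.
    creds, project_id = google.auth.load_credentials_from_file("credentials.json")
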
- with mock.patch.object( - google.auth, "load_credentials_from_file", autospec=True - ) as load_creds, mock.patch.object( - google.auth, "default", autospec=True - ) as adc, mock.patch.object( - grpc_helpers, "create_channel" - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - file_creds = ga_credentials.AnonymousCredentials() - load_creds.return_value = (file_creds, None) - adc.return_value = (creds, None) - client = client_class(client_options=options, transport=transport_name) - create_channel.assert_called_with( - "generativelanguage.googleapis.com:443", - credentials=file_creds, - credentials_file=None, - quota_project_id=None, - default_scopes=( -), - scopes=None, - default_host="generativelanguage.googleapis.com", - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize("request_type", [ - text_service.GenerateTextRequest, - dict, -]) -def test_generate_text(request_type, transport: str = 'grpc'): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = text_service.GenerateTextResponse( - ) - response = client.generate_text(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == text_service.GenerateTextRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, text_service.GenerateTextResponse) - - -def test_generate_text_empty_call(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: - client.generate_text() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == text_service.GenerateTextRequest() - -@pytest.mark.asyncio -async def test_generate_text_async(transport: str = 'grpc_asyncio', request_type=text_service.GenerateTextRequest): - client = TextServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(text_service.GenerateTextResponse( - )) - response = await client.generate_text(request) - - # Establish that the underlying gRPC stub method was called. 
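A quick reminder of the mock bookkeeping used in the assertions that follow: each entry of Mock.mock_calls is a (name, args, kwargs) triple, so unpacking it as "_, args, _ = call.mock_calls[0]" recovers the positional arguments, and args[0] is the request proto that was sent. A stdlib-only illustration:

    from unittest import mock

    m = mock.Mock()
    m("request", retry=None)
    name, args, kwargs = m.mock_calls[0]
    assert args == ("request",) and kwargs == {"retry": None}
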
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == text_service.GenerateTextRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, text_service.GenerateTextResponse) - - -@pytest.mark.asyncio -async def test_generate_text_async_from_dict(): - await test_generate_text_async(request_type=dict) - - -def test_generate_text_field_headers(): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = text_service.GenerateTextRequest() - - request.model = 'model_value' - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: - call.return_value = text_service.GenerateTextResponse() - client.generate_text(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] - - -@pytest.mark.asyncio -async def test_generate_text_field_headers_async(): - client = TextServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = text_service.GenerateTextRequest() - - request.model = 'model_value' - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.GenerateTextResponse()) - await client.generate_text(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] - - -def test_generate_text_flattened(): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = text_service.GenerateTextResponse() - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.generate_text( - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), - temperature=0.1198, - candidate_count=1573, - max_output_tokens=1865, - top_p=0.546, - top_k=541, - ) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - arg = args[0].model - mock_val = 'model_value' - assert arg == mock_val - arg = args[0].prompt - mock_val = text_service.TextPrompt(text='text_value') - assert arg == mock_val - assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) - arg = args[0].candidate_count - mock_val = 1573 - assert arg == mock_val - arg = args[0].max_output_tokens - mock_val = 1865 - assert arg == mock_val - assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6) - arg = args[0].top_k - mock_val = 541 - assert arg == mock_val - - -def test_generate_text_flattened_error(): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.generate_text( - text_service.GenerateTextRequest(), - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), - temperature=0.1198, - candidate_count=1573, - max_output_tokens=1865, - top_p=0.546, - top_k=541, - ) - -@pytest.mark.asyncio -async def test_generate_text_flattened_async(): - client = TextServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = text_service.GenerateTextResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.GenerateTextResponse()) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.generate_text( - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), - temperature=0.1198, - candidate_count=1573, - max_output_tokens=1865, - top_p=0.546, - top_k=541, - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - arg = args[0].model - mock_val = 'model_value' - assert arg == mock_val - arg = args[0].prompt - mock_val = text_service.TextPrompt(text='text_value') - assert arg == mock_val - assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) - arg = args[0].candidate_count - mock_val = 1573 - assert arg == mock_val - arg = args[0].max_output_tokens - mock_val = 1865 - assert arg == mock_val - assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6) - arg = args[0].top_k - mock_val = 541 - assert arg == mock_val - -@pytest.mark.asyncio -async def test_generate_text_flattened_error_async(): - client = TextServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. 
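The convention being enforced below: every RPC accepts either a single request proto or the individual "flattened" keyword fields, never both at once. Sketched as a fragment with hypothetical values, reusing this module's client and types:

    # request-object style
    client.generate_text(request=text_service.GenerateTextRequest(model="models/m"))
    # flattened style
    client.generate_text(model="models/m", prompt=text_service.TextPrompt(text="hi"))
    # supplying both a request and flattened fields raises ValueError, as asserted below
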
- with pytest.raises(ValueError): - await client.generate_text( - text_service.GenerateTextRequest(), - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), - temperature=0.1198, - candidate_count=1573, - max_output_tokens=1865, - top_p=0.546, - top_k=541, - ) - - -@pytest.mark.parametrize("request_type", [ - text_service.EmbedTextRequest, - dict, -]) -def test_embed_text(request_type, transport: str = 'grpc'): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = text_service.EmbedTextResponse( - ) - response = client.embed_text(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == text_service.EmbedTextRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, text_service.EmbedTextResponse) - - -def test_embed_text_empty_call(): - # This test is a coverage failsafe to make sure that totally empty calls, - # i.e. request == None and no flattened fields passed, work. - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: - client.embed_text() - call.assert_called() - _, args, _ = call.mock_calls[0] - assert args[0] == text_service.EmbedTextRequest() - -@pytest.mark.asyncio -async def test_embed_text_async(transport: str = 'grpc_asyncio', request_type=text_service.EmbedTextRequest): - client = TextServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Everything is optional in proto3 as far as the runtime is concerned, - # and we are mocking out the actual API, so just send an empty request. - request = request_type() - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(text_service.EmbedTextResponse( - )) - response = await client.embed_text(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == text_service.EmbedTextRequest() - - # Establish that the response is the type that we expect. - assert isinstance(response, text_service.EmbedTextResponse) - - -@pytest.mark.asyncio -async def test_embed_text_async_from_dict(): - await test_embed_text_async(request_type=dict) - - -def test_embed_text_field_headers(): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = text_service.EmbedTextRequest() - - request.model = 'model_value' - - # Mock the actual call within the gRPC stub, and fake the request. 
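One mechanical detail worth flagging for the async variants in this file: the async client awaits the transport call, so the canned response must be wrapped in grpc_helpers_async.FakeUnaryUnaryCall to be awaitable; returning the bare response object would fail at the await. Restated as a fragment, assuming the imports this file already has:

    # Inside an async test: the wrapper makes the mocked stub awaitable.
    call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
        text_service.EmbedTextResponse()
    )
    response = await client.embed_text(request)
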
- with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: - call.return_value = text_service.EmbedTextResponse() - client.embed_text(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] - - -@pytest.mark.asyncio -async def test_embed_text_field_headers_async(): - client = TextServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Any value that is part of the HTTP/1.1 URI should be sent as - # a field header. Set these to a non-empty value. - request = text_service.EmbedTextRequest() - - request.model = 'model_value' - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.EmbedTextResponse()) - await client.embed_text(request) - - # Establish that the underlying gRPC stub method was called. - assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - assert args[0] == request - - # Establish that the field header was sent. - _, _, kw = call.mock_calls[0] - assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] - - -def test_embed_text_flattened(): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = text_service.EmbedTextResponse() - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - client.embed_text( - model='model_value', - text='text_value', - ) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(call.mock_calls) == 1 - _, args, _ = call.mock_calls[0] - arg = args[0].model - mock_val = 'model_value' - assert arg == mock_val - arg = args[0].text - mock_val = 'text_value' - assert arg == mock_val - - -def test_embed_text_flattened_error(): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.embed_text( - text_service.EmbedTextRequest(), - model='model_value', - text='text_value', - ) - -@pytest.mark.asyncio -async def test_embed_text_flattened_async(): - client = TextServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: - # Designate an appropriate return value for the call. - call.return_value = text_service.EmbedTextResponse() - - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.EmbedTextResponse()) - # Call the method with a truthy value for each flattened field, - # using the keyword arguments to the method. - response = await client.embed_text( - model='model_value', - text='text_value', - ) - - # Establish that the underlying call was made with the expected - # request object values. 
- assert len(call.mock_calls) - _, args, _ = call.mock_calls[0] - arg = args[0].model - mock_val = 'model_value' - assert arg == mock_val - arg = args[0].text - mock_val = 'text_value' - assert arg == mock_val - -@pytest.mark.asyncio -async def test_embed_text_flattened_error_async(): - client = TextServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - await client.embed_text( - text_service.EmbedTextRequest(), - model='model_value', - text='text_value', - ) - - -@pytest.mark.parametrize("request_type", [ - text_service.GenerateTextRequest, - dict, -]) -def test_generate_text_rest(request_type): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - - # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} - request = request_type(**request_init) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: - # Designate an appropriate value for the returned response. - return_value = text_service.GenerateTextResponse( - ) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - pb_return_value = text_service.GenerateTextResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) - - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - response = client.generate_text(request) - - # Establish that the response is the type that we expect. - assert isinstance(response, text_service.GenerateTextResponse) - - -def test_generate_text_rest_required_fields(request_type=text_service.GenerateTextRequest): - transport_class = transports.TextServiceRestTransport - - request_init = {} - request_init["model"] = "" - request = request_type(**request_init) - pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) - - # verify fields with default values are dropped - - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_text._get_unset_required_fields(jsonified_request) - jsonified_request.update(unset_fields) - - # verify required fields with default values are now present - - jsonified_request["model"] = 'model_value' - - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_text._get_unset_required_fields(jsonified_request) - jsonified_request.update(unset_fields) - - # verify required fields with non-default values are left alone - assert "model" in jsonified_request - assert jsonified_request["model"] == 'model_value' - - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest', - ) - request = request_type(**request_init) - - # Designate an appropriate value for the returned response. - return_value = text_service.GenerateTextResponse() - # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: - # We need to mock transcode() because providing default values - # for required fields will fail the real version if the http_options - # expect actual values for those fields. 
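Why transcode() is mocked below: for REST, path_template.transcode maps a request proto onto an HTTP verb, URI, body, and query parameters according to the method's http_options. For GenerateTextRequest(model="models/sample1") the real result would be shaped roughly like this (illustrative values, not captured output):

    transcoded = {
        "method": "post",
        "uri": "/v1beta2/models/sample1:generateText",  # {model=models/*} bound from the request
        "body": "<remaining request fields, serialized>",
        "query_params": {},  # fields bound to neither the URI template nor the body
    }
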
-        with mock.patch.object(path_template, 'transcode') as transcode:
-            # A uri without fields and an empty body will force all the
-            # request fields to show up in the query_params.
-            pb_request = request_type.pb(request)
-            transcode_result = {
-                'uri': 'v1/sample_method',
-                'method': "post",
-                'query_params': pb_request,
-            }
-            transcode_result['body'] = pb_request
-            transcode.return_value = transcode_result
-
-            response_value = Response()
-            response_value.status_code = 200
-
-            pb_return_value = text_service.GenerateTextResponse.pb(return_value)
-            json_return_value = json_format.MessageToJson(pb_return_value)
-
-            response_value._content = json_return_value.encode('UTF-8')
-            req.return_value = response_value
-
-            response = client.generate_text(request)
-
-            expected_params = [
-                ('$alt', 'json;enum-encoding=int')
-            ]
-            actual_params = req.call_args.kwargs['params']
-            assert expected_params == actual_params
-
-
-def test_generate_text_rest_unset_required_fields():
-    transport = transports.TextServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
-
-    unset_fields = transport.generate_text._get_unset_required_fields({})
-    assert set(unset_fields) == (set(()) & set(("model", "prompt", )))
-
-
-@pytest.mark.parametrize("null_interceptor", [True, False])
-def test_generate_text_rest_interceptors(null_interceptor):
-    transport = transports.TextServiceRestTransport(
-        credentials=ga_credentials.AnonymousCredentials(),
-        interceptor=None if null_interceptor else transports.TextServiceRestInterceptor(),
-    )
-    client = TextServiceClient(transport=transport)
-    with mock.patch.object(type(client.transport._session), "request") as req, \
-        mock.patch.object(path_template, "transcode") as transcode, \
-        mock.patch.object(transports.TextServiceRestInterceptor, "post_generate_text") as post, \
-        mock.patch.object(transports.TextServiceRestInterceptor, "pre_generate_text") as pre:
-        pre.assert_not_called()
-        post.assert_not_called()
-        pb_message = text_service.GenerateTextRequest.pb(text_service.GenerateTextRequest())
-        transcode.return_value = {
-            "method": "post",
-            "uri": "my_uri",
-            "body": pb_message,
-            "query_params": pb_message,
-        }
-
-        req.return_value = Response()
-        req.return_value.status_code = 200
-        req.return_value.request = PreparedRequest()
-        req.return_value._content = text_service.GenerateTextResponse.to_json(text_service.GenerateTextResponse())
-
-        request = text_service.GenerateTextRequest()
-        metadata = [
-            ("key", "val"),
-            ("cephalopod", "squid"),
-        ]
-        pre.return_value = request, metadata
-        post.return_value = text_service.GenerateTextResponse()
-
-        client.generate_text(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
-
-        pre.assert_called_once()
-        post.assert_called_once()
-
-
-def test_generate_text_rest_bad_request(transport: str = 'rest', request_type=text_service.GenerateTextRequest):
-    client = TextServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-        transport=transport,
-    )
-
-    # send a request that will satisfy transcoding
-    request_init = {'model': 'models/sample1'}
-    request = request_type(**request_init)
-
-    # Mock the http request call within the method and fake a BadRequest error.
- with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 400 - response_value.request = Request() - req.return_value = response_value - client.generate_text(request) - - -def test_generate_text_rest_flattened(): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: - # Designate an appropriate value for the returned response. - return_value = text_service.GenerateTextResponse() - - # get arguments that satisfy an http rule for this method - sample_request = {'model': 'models/sample1'} - - # get truthy value for each flattened field - mock_args = dict( - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), - temperature=0.1198, - candidate_count=1573, - max_output_tokens=1865, - top_p=0.546, - top_k=541, - ) - mock_args.update(sample_request) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - pb_return_value = text_service.GenerateTextResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - - client.generate_text(**mock_args) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(req.mock_calls) == 1 - _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta2/{model=models/*}:generateText" % client.transport._host, args[1]) - - -def test_generate_text_rest_flattened_error(transport: str = 'rest'): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.generate_text( - text_service.GenerateTextRequest(), - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), - temperature=0.1198, - candidate_count=1573, - max_output_tokens=1865, - top_p=0.546, - top_k=541, - ) - - -def test_generate_text_rest_error(): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' - ) - - -@pytest.mark.parametrize("request_type", [ - text_service.EmbedTextRequest, - dict, -]) -def test_embed_text_rest(request_type): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="rest", - ) - - # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} - request = request_type(**request_init) - - # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: - # Designate an appropriate value for the returned response. 
- return_value = text_service.EmbedTextResponse( - ) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - pb_return_value = text_service.EmbedTextResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) - - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - response = client.embed_text(request) - - # Establish that the response is the type that we expect. - assert isinstance(response, text_service.EmbedTextResponse) - - -def test_embed_text_rest_required_fields(request_type=text_service.EmbedTextRequest): - transport_class = transports.TextServiceRestTransport - - request_init = {} - request_init["model"] = "" - request_init["text"] = "" - request = request_type(**request_init) - pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) - - # verify fields with default values are dropped - - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).embed_text._get_unset_required_fields(jsonified_request) - jsonified_request.update(unset_fields) - - # verify required fields with default values are now present - - jsonified_request["model"] = 'model_value' - jsonified_request["text"] = 'text_value' - - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).embed_text._get_unset_required_fields(jsonified_request) - jsonified_request.update(unset_fields) - - # verify required fields with non-default values are left alone - assert "model" in jsonified_request - assert jsonified_request["model"] == 'model_value' - assert "text" in jsonified_request - assert jsonified_request["text"] == 'text_value' - - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest', - ) - request = request_type(**request_init) - - # Designate an appropriate value for the returned response. - return_value = text_service.EmbedTextResponse() - # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: - # We need to mock transcode() because providing default values - # for required fields will fail the real version if the http_options - # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: - # A uri without fields and an empty body will force all the - # request fields to show up in the query_params. 
-            pb_request = request_type.pb(request)
-            transcode_result = {
-                'uri': 'v1/sample_method',
-                'method': "post",
-                'query_params': pb_request,
-            }
-            transcode_result['body'] = pb_request
-            transcode.return_value = transcode_result
-
-            response_value = Response()
-            response_value.status_code = 200
-
-            pb_return_value = text_service.EmbedTextResponse.pb(return_value)
-            json_return_value = json_format.MessageToJson(pb_return_value)
-
-            response_value._content = json_return_value.encode('UTF-8')
-            req.return_value = response_value
-
-            response = client.embed_text(request)
-
-            expected_params = [
-                ('$alt', 'json;enum-encoding=int')
-            ]
-            actual_params = req.call_args.kwargs['params']
-            assert expected_params == actual_params
-
-
-def test_embed_text_rest_unset_required_fields():
-    transport = transports.TextServiceRestTransport(credentials=ga_credentials.AnonymousCredentials())
-
-    unset_fields = transport.embed_text._get_unset_required_fields({})
-    assert set(unset_fields) == (set(()) & set(("model", "text", )))
-
-
-@pytest.mark.parametrize("null_interceptor", [True, False])
-def test_embed_text_rest_interceptors(null_interceptor):
-    transport = transports.TextServiceRestTransport(
-        credentials=ga_credentials.AnonymousCredentials(),
-        interceptor=None if null_interceptor else transports.TextServiceRestInterceptor(),
-        )
-    client = TextServiceClient(transport=transport)
-    with mock.patch.object(type(client.transport._session), "request") as req, \
-         mock.patch.object(path_template, "transcode") as transcode, \
-         mock.patch.object(transports.TextServiceRestInterceptor, "post_embed_text") as post, \
-         mock.patch.object(transports.TextServiceRestInterceptor, "pre_embed_text") as pre:
-        pre.assert_not_called()
-        post.assert_not_called()
-        pb_message = text_service.EmbedTextRequest.pb(text_service.EmbedTextRequest())
-        transcode.return_value = {
-            "method": "post",
-            "uri": "my_uri",
-            "body": pb_message,
-            "query_params": pb_message,
-        }
-
-        req.return_value = Response()
-        req.return_value.status_code = 200
-        req.return_value.request = PreparedRequest()
-        req.return_value._content = text_service.EmbedTextResponse.to_json(text_service.EmbedTextResponse())
-
-        request = text_service.EmbedTextRequest()
-        metadata = [
-            ("key", "val"),
-            ("cephalopod", "squid"),
-        ]
-        pre.return_value = request, metadata
-        post.return_value = text_service.EmbedTextResponse()
-
-        client.embed_text(request, metadata=[("key", "val"), ("cephalopod", "squid"),])
-
-        pre.assert_called_once()
-        post.assert_called_once()
-
-
-def test_embed_text_rest_bad_request(transport: str = 'rest', request_type=text_service.EmbedTextRequest):
-    client = TextServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-        transport=transport,
-    )
-
-    # send a request that will satisfy transcoding
-    request_init = {'model': 'models/sample1'}
-    request = request_type(**request_init)
-
-    # Mock the http request call within the method and fake a BadRequest error.
-    with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest):
-        # Wrap the value into a proper Response obj
-        response_value = Response()
-        response_value.status_code = 400
-        response_value.request = Request()
-        req.return_value = response_value
-        client.embed_text(request)
-
-
-def test_embed_text_rest_flattened():
-    client = TextServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-        transport="rest",
-    )
-
-    # Mock the http request call within the method and fake a response.
- with mock.patch.object(type(client.transport._session), 'request') as req: - # Designate an appropriate value for the returned response. - return_value = text_service.EmbedTextResponse() - - # get arguments that satisfy an http rule for this method - sample_request = {'model': 'models/sample1'} - - # get truthy value for each flattened field - mock_args = dict( - model='model_value', - text='text_value', - ) - mock_args.update(sample_request) - - # Wrap the value into a proper Response obj - response_value = Response() - response_value.status_code = 200 - pb_return_value = text_service.EmbedTextResponse.pb(return_value) - json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') - req.return_value = response_value - - client.embed_text(**mock_args) - - # Establish that the underlying call was made with the expected - # request object values. - assert len(req.mock_calls) == 1 - _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta2/{model=models/*}:embedText" % client.transport._host, args[1]) - - -def test_embed_text_rest_flattened_error(transport: str = 'rest'): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # Attempting to call a method with both a request object and flattened - # fields is an error. - with pytest.raises(ValueError): - client.embed_text( - text_service.EmbedTextRequest(), - model='model_value', - text='text_value', - ) - - -def test_embed_text_rest_error(): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' - ) - - -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.TextServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - - # It is an error to provide a credentials file and a transport instance. - transport = transports.TextServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = TextServiceClient( - client_options={"credentials_file": "credentials.json"}, - transport=transport, - ) - - # It is an error to provide an api_key and a transport instance. - transport = transports.TextServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - options = client_options.ClientOptions() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = TextServiceClient( - client_options=options, - transport=transport, - ) - - # It is an error to provide an api_key and a credential. - options = mock.Mock() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = TextServiceClient( - client_options=options, - credentials=ga_credentials.AnonymousCredentials() - ) - - # It is an error to provide scopes and a transport instance. - transport = transports.TextServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = TextServiceClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, - ) - - -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. 
- transport = transports.TextServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - client = TextServiceClient(transport=transport) - assert client.transport is transport - -def test_transport_get_channel(): - # A client may be instantiated with a custom transport instance. - transport = transports.TextServiceGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - - transport = transports.TextServiceGrpcAsyncIOTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - channel = transport.grpc_channel - assert channel - -@pytest.mark.parametrize("transport_class", [ - transports.TextServiceGrpcTransport, - transports.TextServiceGrpcAsyncIOTransport, - transports.TextServiceRestTransport, -]) -def test_transport_adc(transport_class): - # Test default credentials are used if not provided. - with mock.patch.object(google.auth, 'default') as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - transport_class() - adc.assert_called_once() - -@pytest.mark.parametrize("transport_name", [ - "grpc", - "rest", -]) -def test_transport_kind(transport_name): - transport = TextServiceClient.get_transport_class(transport_name)( - credentials=ga_credentials.AnonymousCredentials(), - ) - assert transport.kind == transport_name - -def test_transport_grpc_default(): - # A client should use the gRPC transport by default. - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - ) - assert isinstance( - client.transport, - transports.TextServiceGrpcTransport, - ) - -def test_text_service_base_transport_error(): - # Passing both a credentials object and credentials_file should raise an error - with pytest.raises(core_exceptions.DuplicateCredentialArgs): - transport = transports.TextServiceTransport( - credentials=ga_credentials.AnonymousCredentials(), - credentials_file="credentials.json" - ) - - -def test_text_service_base_transport(): - # Instantiate the base transport. - with mock.patch('google.ai.generativelanguage_v1beta2.services.text_service.transports.TextServiceTransport.__init__') as Transport: - Transport.return_value = None - transport = transports.TextServiceTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - - # Every method on the transport should just blindly - # raise NotImplementedError. 
- methods = ( - 'generate_text', - 'embed_text', - ) - for method in methods: - with pytest.raises(NotImplementedError): - getattr(transport, method)(request=object()) - - with pytest.raises(NotImplementedError): - transport.close() - - # Catch all for all remaining methods and properties - remainder = [ - 'kind', - ] - for r in remainder: - with pytest.raises(NotImplementedError): - getattr(transport, r)() - - -def test_text_service_base_transport_with_credentials_file(): - # Instantiate the base transport with a credentials file - with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta2.services.text_service.transports.TextServiceTransport._prep_wrapped_messages') as Transport: - Transport.return_value = None - load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) - transport = transports.TextServiceTransport( - credentials_file="credentials.json", - quota_project_id="octopus", - ) - load_creds.assert_called_once_with("credentials.json", - scopes=None, - default_scopes=( -), - quota_project_id="octopus", - ) - - -def test_text_service_base_transport_with_adc(): - # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta2.services.text_service.transports.TextServiceTransport._prep_wrapped_messages') as Transport: - Transport.return_value = None - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - transport = transports.TextServiceTransport() - adc.assert_called_once() - - -def test_text_service_auth_adc(): - # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(google.auth, 'default', autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - TextServiceClient() - adc.assert_called_once_with( - scopes=None, - default_scopes=( -), - quota_project_id=None, - ) - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.TextServiceGrpcTransport, - transports.TextServiceGrpcAsyncIOTransport, - ], -) -def test_text_service_transport_auth_adc(transport_class): - # If credentials and host are not provided, the transport class should use - # ADC credentials. 
- with mock.patch.object(google.auth, 'default', autospec=True) as adc: - adc.return_value = (ga_credentials.AnonymousCredentials(), None) - transport_class(quota_project_id="octopus", scopes=["1", "2"]) - adc.assert_called_once_with( - scopes=["1", "2"], - default_scopes=(), - quota_project_id="octopus", - ) - - -@pytest.mark.parametrize( - "transport_class", - [ - transports.TextServiceGrpcTransport, - transports.TextServiceGrpcAsyncIOTransport, - transports.TextServiceRestTransport, - ], -) -def test_text_service_transport_auth_gdch_credentials(transport_class): - host = 'https://language.com' - api_audience_tests = [None, 'https://language2.com'] - api_audience_expect = [host, 'https://language2.com'] - for t, e in zip(api_audience_tests, api_audience_expect): - with mock.patch.object(google.auth, 'default', autospec=True) as adc: - gdch_mock = mock.MagicMock() - type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock) - adc.return_value = (gdch_mock, None) - transport_class(host=host, api_audience=t) - gdch_mock.with_gdch_audience.assert_called_once_with( - e - ) - - -@pytest.mark.parametrize( - "transport_class,grpc_helpers", - [ - (transports.TextServiceGrpcTransport, grpc_helpers), - (transports.TextServiceGrpcAsyncIOTransport, grpc_helpers_async) - ], -) -def test_text_service_transport_create_channel(transport_class, grpc_helpers): - # If credentials and host are not provided, the transport class should use - # ADC credentials. - with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object( - grpc_helpers, "create_channel", autospec=True - ) as create_channel: - creds = ga_credentials.AnonymousCredentials() - adc.return_value = (creds, None) - transport_class( - quota_project_id="octopus", - scopes=["1", "2"] - ) - - create_channel.assert_called_with( - "generativelanguage.googleapis.com:443", - credentials=creds, - credentials_file=None, - quota_project_id="octopus", - default_scopes=( -), - scopes=["1", "2"], - default_host="generativelanguage.googleapis.com", - ssl_credentials=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - -@pytest.mark.parametrize("transport_class", [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport]) -def test_text_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): - cred = ga_credentials.AnonymousCredentials() - - # Check ssl_channel_credentials is used if provided. - with mock.patch.object(transport_class, "create_channel") as mock_create_channel: - mock_ssl_channel_creds = mock.Mock() - transport_class( - host="squid.clam.whelk", - credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds - ) - mock_create_channel.assert_called_once_with( - "squid.clam.whelk:443", - credentials=cred, - credentials_file=None, - scopes=None, - ssl_credentials=mock_ssl_channel_creds, - quota_project_id=None, - options=[ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ], - ) - - # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls - # is used. 
-    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
-        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
-            transport_class(
-                credentials=cred,
-                client_cert_source_for_mtls=client_cert_source_callback
-            )
-            expected_cert, expected_key = client_cert_source_callback()
-            mock_ssl_cred.assert_called_once_with(
-                certificate_chain=expected_cert,
-                private_key=expected_key
-            )
-
-def test_text_service_http_transport_client_cert_source_for_mtls():
-    cred = ga_credentials.AnonymousCredentials()
-    with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel:
-        transports.TextServiceRestTransport(
-            credentials=cred,
-            client_cert_source_for_mtls=client_cert_source_callback
-        )
-        mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback)
-
-
-@pytest.mark.parametrize("transport_name", [
-    "grpc",
-    "grpc_asyncio",
-    "rest",
-])
-def test_text_service_host_no_port(transport_name):
-    client = TextServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'),
-        transport=transport_name,
-    )
-    assert client.transport._host == (
-        'generativelanguage.googleapis.com:443'
-        if transport_name in ['grpc', 'grpc_asyncio']
-        else 'https://generativelanguage.googleapis.com'
-    )
-
-@pytest.mark.parametrize("transport_name", [
-    "grpc",
-    "grpc_asyncio",
-    "rest",
-])
-def test_text_service_host_with_port(transport_name):
-    client = TextServiceClient(
-        credentials=ga_credentials.AnonymousCredentials(),
-        client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'),
-        transport=transport_name,
-    )
-    assert client.transport._host == (
-        'generativelanguage.googleapis.com:8000'
-        if transport_name in ['grpc', 'grpc_asyncio']
-        else 'https://generativelanguage.googleapis.com:8000'
-    )
-
-@pytest.mark.parametrize("transport_name", [
-    "rest",
-])
-def test_text_service_client_transport_session_collision(transport_name):
-    creds1 = ga_credentials.AnonymousCredentials()
-    creds2 = ga_credentials.AnonymousCredentials()
-    client1 = TextServiceClient(
-        credentials=creds1,
-        transport=transport_name,
-    )
-    client2 = TextServiceClient(
-        credentials=creds2,
-        transport=transport_name,
-    )
-    session1 = client1.transport.generate_text._session
-    session2 = client2.transport.generate_text._session
-    assert session1 != session2
-    session1 = client1.transport.embed_text._session
-    session2 = client2.transport.embed_text._session
-    assert session1 != session2
-def test_text_service_grpc_transport_channel():
-    channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials())
-
-    # Check that channel is used if provided.
-    transport = transports.TextServiceGrpcTransport(
-        host="squid.clam.whelk",
-        channel=channel,
-    )
-    assert transport.grpc_channel == channel
-    assert transport._host == "squid.clam.whelk:443"
-    assert transport._ssl_channel_credentials is None
-
-
-def test_text_service_grpc_asyncio_transport_channel():
-    channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials())
-
-    # Check that channel is used if provided.
-    transport = transports.TextServiceGrpcAsyncIOTransport(
-        host="squid.clam.whelk",
-        channel=channel,
-    )
-    assert transport.grpc_channel == channel
-    assert transport._host == "squid.clam.whelk:443"
-    assert transport._ssl_channel_credentials is None
-
-
-# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
-# removed from grpc/grpc_asyncio transport constructor.
-@pytest.mark.parametrize("transport_class", [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport])
-def test_text_service_transport_channel_mtls_with_client_cert_source(
-    transport_class
-):
-    with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred:
-        with mock.patch.object(transport_class, "create_channel") as grpc_create_channel:
-            mock_ssl_cred = mock.Mock()
-            grpc_ssl_channel_cred.return_value = mock_ssl_cred
-
-            mock_grpc_channel = mock.Mock()
-            grpc_create_channel.return_value = mock_grpc_channel
-
-            cred = ga_credentials.AnonymousCredentials()
-            with pytest.warns(DeprecationWarning):
-                with mock.patch.object(google.auth, 'default') as adc:
-                    adc.return_value = (cred, None)
-                    transport = transport_class(
-                        host="squid.clam.whelk",
-                        api_mtls_endpoint="mtls.squid.clam.whelk",
-                        client_cert_source=client_cert_source_callback,
-                    )
-                    adc.assert_called_once()
-
-            grpc_ssl_channel_cred.assert_called_once_with(
-                certificate_chain=b"cert bytes", private_key=b"key bytes"
-            )
-            grpc_create_channel.assert_called_once_with(
-                "mtls.squid.clam.whelk:443",
-                credentials=cred,
-                credentials_file=None,
-                scopes=None,
-                ssl_credentials=mock_ssl_cred,
-                quota_project_id=None,
-                options=[
-                    ("grpc.max_send_message_length", -1),
-                    ("grpc.max_receive_message_length", -1),
-                ],
-            )
-            assert transport.grpc_channel == mock_grpc_channel
-            assert transport._ssl_channel_credentials == mock_ssl_cred
-
-
-# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
-# removed from grpc/grpc_asyncio transport constructor.
-@pytest.mark.parametrize("transport_class", [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport])
-def test_text_service_transport_channel_mtls_with_adc(
-    transport_class
-):
-    mock_ssl_cred = mock.Mock()
-    with mock.patch.multiple(
-        "google.auth.transport.grpc.SslCredentials",
-        __init__=mock.Mock(return_value=None),
-        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
-    ):
-        with mock.patch.object(transport_class, "create_channel") as grpc_create_channel:
-            mock_grpc_channel = mock.Mock()
-            grpc_create_channel.return_value = mock_grpc_channel
-            mock_cred = mock.Mock()
-
-            with pytest.warns(DeprecationWarning):
-                transport = transport_class(
-                    host="squid.clam.whelk",
-                    credentials=mock_cred,
-                    api_mtls_endpoint="mtls.squid.clam.whelk",
-                    client_cert_source=None,
-                )
-
-            grpc_create_channel.assert_called_once_with(
-                "mtls.squid.clam.whelk:443",
-                credentials=mock_cred,
-                credentials_file=None,
-                scopes=None,
-                ssl_credentials=mock_ssl_cred,
-                quota_project_id=None,
-                options=[
-                    ("grpc.max_send_message_length", -1),
-                    ("grpc.max_receive_message_length", -1),
-                ],
-            )
-            assert transport.grpc_channel == mock_grpc_channel
-
-
-def test_model_path():
-    model = "squid"
-    expected = "models/{model}".format(model=model, )
-    actual = TextServiceClient.model_path(model)
-    assert expected == actual
-
-
-def test_parse_model_path():
-    expected = {
-        "model": "clam",
-    }
-    path = TextServiceClient.model_path(**expected)
-
-    # Check that the path construction is reversible.
- actual = TextServiceClient.parse_model_path(path) - assert expected == actual - -def test_common_billing_account_path(): - billing_account = "whelk" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) - actual = TextServiceClient.common_billing_account_path(billing_account) - assert expected == actual - - -def test_parse_common_billing_account_path(): - expected = { - "billing_account": "octopus", - } - path = TextServiceClient.common_billing_account_path(**expected) - - # Check that the path construction is reversible. - actual = TextServiceClient.parse_common_billing_account_path(path) - assert expected == actual - -def test_common_folder_path(): - folder = "oyster" - expected = "folders/{folder}".format(folder=folder, ) - actual = TextServiceClient.common_folder_path(folder) - assert expected == actual - - -def test_parse_common_folder_path(): - expected = { - "folder": "nudibranch", - } - path = TextServiceClient.common_folder_path(**expected) - - # Check that the path construction is reversible. - actual = TextServiceClient.parse_common_folder_path(path) - assert expected == actual - -def test_common_organization_path(): - organization = "cuttlefish" - expected = "organizations/{organization}".format(organization=organization, ) - actual = TextServiceClient.common_organization_path(organization) - assert expected == actual - - -def test_parse_common_organization_path(): - expected = { - "organization": "mussel", - } - path = TextServiceClient.common_organization_path(**expected) - - # Check that the path construction is reversible. - actual = TextServiceClient.parse_common_organization_path(path) - assert expected == actual - -def test_common_project_path(): - project = "winkle" - expected = "projects/{project}".format(project=project, ) - actual = TextServiceClient.common_project_path(project) - assert expected == actual - - -def test_parse_common_project_path(): - expected = { - "project": "nautilus", - } - path = TextServiceClient.common_project_path(**expected) - - # Check that the path construction is reversible. - actual = TextServiceClient.parse_common_project_path(path) - assert expected == actual - -def test_common_location_path(): - project = "scallop" - location = "abalone" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) - actual = TextServiceClient.common_location_path(project, location) - assert expected == actual - - -def test_parse_common_location_path(): - expected = { - "project": "squid", - "location": "clam", - } - path = TextServiceClient.common_location_path(**expected) - - # Check that the path construction is reversible. 
- actual = TextServiceClient.parse_common_location_path(path) - assert expected == actual - - -def test_client_with_default_client_info(): - client_info = gapic_v1.client_info.ClientInfo() - - with mock.patch.object(transports.TextServiceTransport, '_prep_wrapped_messages') as prep: - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - client_info=client_info, - ) - prep.assert_called_once_with(client_info) - - with mock.patch.object(transports.TextServiceTransport, '_prep_wrapped_messages') as prep: - transport_class = TextServiceClient.get_transport_class() - transport = transport_class( - credentials=ga_credentials.AnonymousCredentials(), - client_info=client_info, - ) - prep.assert_called_once_with(client_info) - -@pytest.mark.asyncio -async def test_transport_close_async(): - client = TextServiceAsyncClient( - credentials=ga_credentials.AnonymousCredentials(), - transport="grpc_asyncio", - ) - with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close: - async with client: - close.assert_not_called() - close.assert_called_once() - - -def test_transport_close(): - transports = { - "rest": "_session", - "grpc": "_grpc_channel", - } - - for transport, close_name in transports.items(): - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport - ) - with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close: - with client: - close.assert_not_called() - close.assert_called_once() - -def test_client_ctx(): - transports = [ - 'rest', - 'grpc', - ] - for transport in transports: - client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport - ) - # Test client calls underlying transport. 
- with mock.patch.object(type(client.transport), "close") as close: - close.assert_not_called() - with client: - pass - close.assert_called() - -@pytest.mark.parametrize("client_class,transport_class", [ - (TextServiceClient, transports.TextServiceGrpcTransport), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport), -]) -def test_api_key_credentials(client_class, transport_class): - with mock.patch.object( - google.auth._default, "get_api_key_credentials", create=True - ) as get_api_key_credentials: - mock_cred = mock.Mock() - get_api_key_credentials.return_value = mock_cred - options = client_options.ClientOptions() - options.api_key = "api_key" - with mock.patch.object(transport_class, "__init__") as patched: - patched.return_value = None - client = client_class(client_options=options) - patched.assert_called_once_with( - credentials=mock_cred, - credentials_file=None, - host=client.DEFAULT_ENDPOINT, - scopes=None, - client_cert_source_for_mtls=None, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - always_use_jwt_access=True, - api_audience=None, - ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/.coveragerc b/owl-bot-staging/google-ai-generativelanguage/v1beta3/.coveragerc deleted file mode 100644 index fd060ae956b5..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/.coveragerc +++ /dev/null @@ -1,13 +0,0 @@ -[run] -branch = True - -[report] -show_missing = True -omit = - google/ai/generativelanguage/__init__.py - google/ai/generativelanguage/gapic_version.py -exclude_lines = - # Re-enable the standard pragma - pragma: NO COVER - # Ignore debug-only repr - def __repr__ diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/.flake8 b/owl-bot-staging/google-ai-generativelanguage/v1beta3/.flake8 deleted file mode 100644 index 29227d4cf419..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/.flake8 +++ /dev/null @@ -1,33 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Generated by synthtool. DO NOT EDIT! -[flake8] -ignore = E203, E266, E501, W503 -exclude = - # Exclude generated code. - **/proto/** - **/gapic/** - **/services/** - **/types/** - *_pb2.py - - # Standard linting exemptions. 
-  **/.nox/**
-  __pycache__,
-  .git,
-  *.pyc,
-  conf.py
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/MANIFEST.in b/owl-bot-staging/google-ai-generativelanguage/v1beta3/MANIFEST.in
deleted file mode 100644
index a41cec0defac..000000000000
--- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/MANIFEST.in
+++ /dev/null
@@ -1,2 +0,0 @@
-recursive-include google/ai/generativelanguage *.py
-recursive-include google/ai/generativelanguage_v1beta3 *.py
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/README.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta3/README.rst
deleted file mode 100644
index 099f73894711..000000000000
--- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/README.rst
+++ /dev/null
@@ -1,49 +0,0 @@
-Python Client for Google Ai Generativelanguage API
-==================================================
-
-Quick Start
------------
-
-In order to use this library, you first need to go through the following steps:
-
-1. `Select or create a Cloud Platform project.`_
-2. `Enable billing for your project.`_
-3. Enable the Google Ai Generativelanguage API.
-4. `Setup Authentication.`_
-
-.. _Select or create a Cloud Platform project.: https://console.cloud.google.com/project
-.. _Enable billing for your project.: https://cloud.google.com/billing/docs/how-to/modify-project#enable_billing_for_a_project
-.. _Setup Authentication.: https://googleapis.dev/python/google-api-core/latest/auth.html
-
-Installation
-~~~~~~~~~~~~
-
-Install this library in a `virtualenv`_ using pip. `virtualenv`_ is a tool to
-create isolated Python environments. The basic problem it addresses is one of
-dependencies and versions, and indirectly permissions.
-
-With `virtualenv`_, it's possible to install this library without needing system
-install permissions, and without clashing with the installed system
-dependencies.
-
-.. _`virtualenv`: https://virtualenv.pypa.io/en/latest/
-
-
-Mac/Linux
-^^^^^^^^^
-
-.. code-block:: console
-
-    python3 -m venv <your-env>
-    source <your-env>/bin/activate
-    <your-env>/bin/pip install /path/to/library
-
-
-Windows
-^^^^^^^
-
-.. code-block:: console
-
-    python3 -m venv <your-env>
-    <your-env>\Scripts\activate
-    <your-env>\Scripts\pip.exe install \path\to\library
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/_static/custom.css b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/_static/custom.css
deleted file mode 100644
index 06423be0b592..000000000000
--- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/_static/custom.css
+++ /dev/null
@@ -1,3 +0,0 @@
-dl.field-list > dt {
-    min-width: 100px
-}
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/conf.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/conf.py
deleted file mode 100644
index 0f3f4903ff54..000000000000
--- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/conf.py
+++ /dev/null
@@ -1,376 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2023 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-#
-# google-ai-generativelanguage documentation build configuration file
-#
-# This file is execfile()d with the current directory set to its
-# containing dir.
-#
-# Note that not all possible configuration values are present in this
-# autogenerated file.
-#
-# All configuration values have a default; values that are commented out
-# serve to show the default.
-
-import sys
-import os
-import shlex
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-sys.path.insert(0, os.path.abspath(".."))
-
-__version__ = "0.1.0"
-
-# -- General configuration ------------------------------------------------
-
-# If your documentation needs a minimal Sphinx version, state it here.
-needs_sphinx = "4.0.1"
-
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
-# ones.
-extensions = [
-    "sphinx.ext.autodoc",
-    "sphinx.ext.autosummary",
-    "sphinx.ext.intersphinx",
-    "sphinx.ext.coverage",
-    "sphinx.ext.napoleon",
-    "sphinx.ext.todo",
-    "sphinx.ext.viewcode",
-]
-
-# autodoc/autosummary flags
-autoclass_content = "both"
-autodoc_default_flags = ["members"]
-autosummary_generate = True
-
-
-# Add any paths that contain templates here, relative to this directory.
-templates_path = ["_templates"]
-
-# Allow markdown includes (so releases.md can include CHANGELOG.md)
-# http://www.sphinx-doc.org/en/master/markdown.html
-source_parsers = {".md": "recommonmark.parser.CommonMarkParser"}
-
-# The suffix(es) of source filenames.
-# You can specify multiple suffix as a list of string:
-source_suffix = [".rst", ".md"]
-
-# The encoding of source files.
-# source_encoding = 'utf-8-sig'
-
-# The root toctree document.
-root_doc = "index"
-
-# General information about the project.
-project = u"google-ai-generativelanguage"
-copyright = u"2023, Google, LLC"
-author = u"Google APIs"  # TODO: autogenerate this bit
-
-# The version info for the project you're documenting, acts as replacement for
-# |version| and |release|, also used in various other places throughout the
-# built documents.
-#
-# The full version, including alpha/beta/rc tags.
-release = __version__
-# The short X.Y version.
-version = ".".join(release.split(".")[0:2])
-
-# The language for content autogenerated by Sphinx. Refer to documentation
-# for a list of supported languages.
-#
-# This is also used if you do content translation via gettext catalogs.
-# Usually you set "language" from the command line for these cases.
-language = 'en'
-
-# There are two options for replacing |today|: either, you set today to some
-# non-false value, then it is used:
-# today = ''
-# Else, today_fmt is used as the format for a strftime call.
-# today_fmt = '%B %d, %Y'
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-exclude_patterns = ["_build"]
-
-# The reST default role (used for this markup: `text`) to use for all
-# documents.
-# default_role = None
-
-# If true, '()' will be appended to :func: etc. cross-reference text.
-# add_function_parentheses = True
-
-# If true, the current module name will be prepended to all description
-# unit titles (such as .. function::).
-# add_module_names = True
-
-# If true, sectionauthor and moduleauthor directives will be shown in the
-# output.
They are ignored by default.
-# show_authors = False
-
-# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = "sphinx"
-
-# A list of ignored prefixes for module index sorting.
-# modindex_common_prefix = []
-
-# If true, keep warnings as "system message" paragraphs in the built documents.
-# keep_warnings = False
-
-# If true, `todo` and `todoList` produce output, else they produce nothing.
-todo_include_todos = True
-
-
-# -- Options for HTML output ----------------------------------------------
-
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-html_theme = "alabaster"
-
-# Theme options are theme-specific and customize the look and feel of a theme
-# further. For a list of options available for each theme, see the
-# documentation.
-html_theme_options = {
-    "description": "Google Ai Client Libraries for Python",
-    "github_user": "googleapis",
-    "github_repo": "google-cloud-python",
-    "github_banner": True,
-    "font_family": "'Roboto', Georgia, sans",
-    "head_font_family": "'Roboto', Georgia, serif",
-    "code_font_family": "'Roboto Mono', 'Consolas', monospace",
-}
-
-# Add any paths that contain custom themes here, relative to this directory.
-# html_theme_path = []
-
-# The name for this set of Sphinx documents. If None, it defaults to
-# "<project> v<release> documentation".
-# html_title = None
-
-# A shorter title for the navigation bar. Default is the same as html_title.
-# html_short_title = None
-
-# The name of an image file (relative to this directory) to place at the top
-# of the sidebar.
-# html_logo = None
-
-# The name of an image file (within the static path) to use as favicon of the
-# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
-# pixels large.
-# html_favicon = None
-
-# Add any paths that contain custom static files (such as style sheets) here,
-# relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ["_static"]
-
-# Add any extra paths that contain custom files (such as robots.txt or
-# .htaccess) here, relative to this directory. These files are copied
-# directly to the root of the documentation.
-# html_extra_path = []
-
-# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
-# using the given strftime format.
-# html_last_updated_fmt = '%b %d, %Y'
-
-# If true, SmartyPants will be used to convert quotes and dashes to
-# typographically correct entities.
-# html_use_smartypants = True
-
-# Custom sidebar templates, maps document names to template names.
-# html_sidebars = {}
-
-# Additional templates that should be rendered to pages, maps page names to
-# template names.
-# html_additional_pages = {}
-
-# If false, no module index is generated.
-# html_domain_indices = True
-
-# If false, no index is generated.
-# html_use_index = True
-
-# If true, the index is split into individual pages for each letter.
-# html_split_index = False
-
-# If true, links to the reST sources are added to the pages.
-# html_show_sourcelink = True
-
-# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
-# html_show_sphinx = True
-
-# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
-# html_show_copyright = True
-
-# If true, an OpenSearch description file will be output, and all pages will
-# contain a <link> tag referring to it.
The value of this option must be the -# base URL from which the finished HTML is served. -# html_use_opensearch = '' - -# This is the file name suffix for HTML files (e.g. ".xhtml"). -# html_file_suffix = None - -# Language to be used for generating the HTML full-text search index. -# Sphinx supports the following languages: -# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' -# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' -# html_search_language = 'en' - -# A dictionary with options for the search language support, empty by default. -# Now only 'ja' uses this config value -# html_search_options = {'type': 'default'} - -# The name of a javascript file (relative to the configuration directory) that -# implements a search results scorer. If empty, the default will be used. -# html_search_scorer = 'scorer.js' - -# Output file base name for HTML help builder. -htmlhelp_basename = "google-ai-generativelanguage-doc" - -# -- Options for warnings ------------------------------------------------------ - - -suppress_warnings = [ - # Temporarily suppress this to avoid "more than one target found for - # cross-reference" warning, which are intractable for us to avoid while in - # a mono-repo. - # See https://github.com/sphinx-doc/sphinx/blob - # /2a65ffeef5c107c19084fabdd706cdff3f52d93c/sphinx/domains/python.py#L843 - "ref.python" -] - -# -- Options for LaTeX output --------------------------------------------- - -latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). - # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. - # 'preamble': '', - # Latex figure (float) alignment - # 'figure_align': 'htbp', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - ( - root_doc, - "google-ai-generativelanguage.tex", - u"google-ai-generativelanguage Documentation", - author, - "manual", - ) -] - -# The name of an image file (relative to this directory) to place at the top of -# the title page. -# latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -# latex_use_parts = False - -# If true, show page references after internal links. -# latex_show_pagerefs = False - -# If true, show URL addresses after external links. -# latex_show_urls = False - -# Documents to append as an appendix to all manuals. -# latex_appendices = [] - -# If false, no module index is generated. -# latex_domain_indices = True - - -# -- Options for manual page output --------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - ( - root_doc, - "google-ai-generativelanguage", - u"Google Ai Generativelanguage Documentation", - [author], - 1, - ) -] - -# If true, show URL addresses after external links. -# man_show_urls = False - - -# -- Options for Texinfo output ------------------------------------------- - -# Grouping the document tree into Texinfo files. 
List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - ( - root_doc, - "google-ai-generativelanguage", - u"google-ai-generativelanguage Documentation", - author, - "google-ai-generativelanguage", - "GAPIC library for Google Ai Generativelanguage API", - "APIs", - ) -] - -# Documents to append as an appendix to all manuals. -# texinfo_appendices = [] - -# If false, no module index is generated. -# texinfo_domain_indices = True - -# How to display URL addresses: 'footnote', 'no', or 'inline'. -# texinfo_show_urls = 'footnote' - -# If true, do not generate a @detailmenu in the "Top" node's menu. -# texinfo_no_detailmenu = False - - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = { - "python": ("http://python.readthedocs.org/en/latest/", None), - "gax": ("https://gax-python.readthedocs.org/en/latest/", None), - "google-auth": ("https://google-auth.readthedocs.io/en/stable", None), - "google-gax": ("https://gax-python.readthedocs.io/en/latest/", None), - "google.api_core": ("https://googleapis.dev/python/google-api-core/latest/", None), - "grpc": ("https://grpc.io/grpc/python/", None), - "requests": ("http://requests.kennethreitz.org/en/stable/", None), - "proto": ("https://proto-plus-python.readthedocs.io/en/stable", None), - "protobuf": ("https://googleapis.dev/python/protobuf/latest/", None), -} - - -# Napoleon settings -napoleon_google_docstring = True -napoleon_numpy_docstring = True -napoleon_include_private_with_doc = False -napoleon_include_special_with_doc = True -napoleon_use_admonition_for_examples = False -napoleon_use_admonition_for_notes = False -napoleon_use_admonition_for_references = False -napoleon_use_ivar = False -napoleon_use_param = True -napoleon_use_rtype = True diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/index.rst b/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/index.rst deleted file mode 100644 index d08223c1a59b..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/index.rst +++ /dev/null @@ -1,7 +0,0 @@ -API Reference -------------- -.. toctree:: - :maxdepth: 2 - - generativelanguage_v1beta3/services - generativelanguage_v1beta3/types diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/__init__.py deleted file mode 100644 index 77d8cbc1869c..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/__init__.py +++ /dev/null @@ -1,145 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -from google.ai.generativelanguage import gapic_version as package_version - -__version__ = package_version.__version__ - - -from google.ai.generativelanguage_v1beta3.services.discuss_service.client import DiscussServiceClient -from google.ai.generativelanguage_v1beta3.services.discuss_service.async_client import DiscussServiceAsyncClient -from google.ai.generativelanguage_v1beta3.services.model_service.client import ModelServiceClient -from google.ai.generativelanguage_v1beta3.services.model_service.async_client import ModelServiceAsyncClient -from google.ai.generativelanguage_v1beta3.services.permission_service.client import PermissionServiceClient -from google.ai.generativelanguage_v1beta3.services.permission_service.async_client import PermissionServiceAsyncClient -from google.ai.generativelanguage_v1beta3.services.text_service.client import TextServiceClient -from google.ai.generativelanguage_v1beta3.services.text_service.async_client import TextServiceAsyncClient - -from google.ai.generativelanguage_v1beta3.types.citation import CitationMetadata -from google.ai.generativelanguage_v1beta3.types.citation import CitationSource -from google.ai.generativelanguage_v1beta3.types.discuss_service import CountMessageTokensRequest -from google.ai.generativelanguage_v1beta3.types.discuss_service import CountMessageTokensResponse -from google.ai.generativelanguage_v1beta3.types.discuss_service import Example -from google.ai.generativelanguage_v1beta3.types.discuss_service import GenerateMessageRequest -from google.ai.generativelanguage_v1beta3.types.discuss_service import GenerateMessageResponse -from google.ai.generativelanguage_v1beta3.types.discuss_service import Message -from google.ai.generativelanguage_v1beta3.types.discuss_service import MessagePrompt -from google.ai.generativelanguage_v1beta3.types.model import Model -from google.ai.generativelanguage_v1beta3.types.model_service import CreateTunedModelMetadata -from google.ai.generativelanguage_v1beta3.types.model_service import CreateTunedModelRequest -from google.ai.generativelanguage_v1beta3.types.model_service import DeleteTunedModelRequest -from google.ai.generativelanguage_v1beta3.types.model_service import GetModelRequest -from google.ai.generativelanguage_v1beta3.types.model_service import GetTunedModelRequest -from google.ai.generativelanguage_v1beta3.types.model_service import ListModelsRequest -from google.ai.generativelanguage_v1beta3.types.model_service import ListModelsResponse -from google.ai.generativelanguage_v1beta3.types.model_service import ListTunedModelsRequest -from google.ai.generativelanguage_v1beta3.types.model_service import ListTunedModelsResponse -from google.ai.generativelanguage_v1beta3.types.model_service import UpdateTunedModelRequest -from google.ai.generativelanguage_v1beta3.types.permission import Permission -from google.ai.generativelanguage_v1beta3.types.permission_service import CreatePermissionRequest -from google.ai.generativelanguage_v1beta3.types.permission_service import DeletePermissionRequest -from google.ai.generativelanguage_v1beta3.types.permission_service import GetPermissionRequest -from google.ai.generativelanguage_v1beta3.types.permission_service import ListPermissionsRequest -from google.ai.generativelanguage_v1beta3.types.permission_service import ListPermissionsResponse -from google.ai.generativelanguage_v1beta3.types.permission_service import TransferOwnershipRequest -from google.ai.generativelanguage_v1beta3.types.permission_service import TransferOwnershipResponse -from 
google.ai.generativelanguage_v1beta3.types.permission_service import UpdatePermissionRequest -from google.ai.generativelanguage_v1beta3.types.safety import ContentFilter -from google.ai.generativelanguage_v1beta3.types.safety import SafetyFeedback -from google.ai.generativelanguage_v1beta3.types.safety import SafetyRating -from google.ai.generativelanguage_v1beta3.types.safety import SafetySetting -from google.ai.generativelanguage_v1beta3.types.safety import HarmCategory -from google.ai.generativelanguage_v1beta3.types.text_service import BatchEmbedTextRequest -from google.ai.generativelanguage_v1beta3.types.text_service import BatchEmbedTextResponse -from google.ai.generativelanguage_v1beta3.types.text_service import CountTextTokensRequest -from google.ai.generativelanguage_v1beta3.types.text_service import CountTextTokensResponse -from google.ai.generativelanguage_v1beta3.types.text_service import Embedding -from google.ai.generativelanguage_v1beta3.types.text_service import EmbedTextRequest -from google.ai.generativelanguage_v1beta3.types.text_service import EmbedTextResponse -from google.ai.generativelanguage_v1beta3.types.text_service import GenerateTextRequest -from google.ai.generativelanguage_v1beta3.types.text_service import GenerateTextResponse -from google.ai.generativelanguage_v1beta3.types.text_service import TextCompletion -from google.ai.generativelanguage_v1beta3.types.text_service import TextPrompt -from google.ai.generativelanguage_v1beta3.types.tuned_model import Dataset -from google.ai.generativelanguage_v1beta3.types.tuned_model import Hyperparameters -from google.ai.generativelanguage_v1beta3.types.tuned_model import TunedModel -from google.ai.generativelanguage_v1beta3.types.tuned_model import TunedModelSource -from google.ai.generativelanguage_v1beta3.types.tuned_model import TuningExample -from google.ai.generativelanguage_v1beta3.types.tuned_model import TuningExamples -from google.ai.generativelanguage_v1beta3.types.tuned_model import TuningSnapshot -from google.ai.generativelanguage_v1beta3.types.tuned_model import TuningTask - -__all__ = ('DiscussServiceClient', - 'DiscussServiceAsyncClient', - 'ModelServiceClient', - 'ModelServiceAsyncClient', - 'PermissionServiceClient', - 'PermissionServiceAsyncClient', - 'TextServiceClient', - 'TextServiceAsyncClient', - 'CitationMetadata', - 'CitationSource', - 'CountMessageTokensRequest', - 'CountMessageTokensResponse', - 'Example', - 'GenerateMessageRequest', - 'GenerateMessageResponse', - 'Message', - 'MessagePrompt', - 'Model', - 'CreateTunedModelMetadata', - 'CreateTunedModelRequest', - 'DeleteTunedModelRequest', - 'GetModelRequest', - 'GetTunedModelRequest', - 'ListModelsRequest', - 'ListModelsResponse', - 'ListTunedModelsRequest', - 'ListTunedModelsResponse', - 'UpdateTunedModelRequest', - 'Permission', - 'CreatePermissionRequest', - 'DeletePermissionRequest', - 'GetPermissionRequest', - 'ListPermissionsRequest', - 'ListPermissionsResponse', - 'TransferOwnershipRequest', - 'TransferOwnershipResponse', - 'UpdatePermissionRequest', - 'ContentFilter', - 'SafetyFeedback', - 'SafetyRating', - 'SafetySetting', - 'HarmCategory', - 'BatchEmbedTextRequest', - 'BatchEmbedTextResponse', - 'CountTextTokensRequest', - 'CountTextTokensResponse', - 'Embedding', - 'EmbedTextRequest', - 'EmbedTextResponse', - 'GenerateTextRequest', - 'GenerateTextResponse', - 'TextCompletion', - 'TextPrompt', - 'Dataset', - 'Hyperparameters', - 'TunedModel', - 'TunedModelSource', - 'TuningExample', - 'TuningExamples', - 'TuningSnapshot', - 
'TuningTask', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/gapic_version.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/gapic_version.py deleted file mode 100644 index 360a0d13ebdd..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/gapic_version.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -__version__ = "0.0.0" # {x-release-please-version} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/py.typed b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/py.typed deleted file mode 100644 index 38773eee6363..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage/py.typed +++ /dev/null @@ -1,2 +0,0 @@ -# Marker file for PEP 561. -# The google-ai-generativelanguage package uses inline types. diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/__init__.py deleted file mode 100644 index 264895e674fe..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/__init__.py +++ /dev/null @@ -1,146 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -from google.ai.generativelanguage_v1beta3 import gapic_version as package_version - -__version__ = package_version.__version__ - - -from .services.discuss_service import DiscussServiceClient -from .services.discuss_service import DiscussServiceAsyncClient -from .services.model_service import ModelServiceClient -from .services.model_service import ModelServiceAsyncClient -from .services.permission_service import PermissionServiceClient -from .services.permission_service import PermissionServiceAsyncClient -from .services.text_service import TextServiceClient -from .services.text_service import TextServiceAsyncClient - -from .types.citation import CitationMetadata -from .types.citation import CitationSource -from .types.discuss_service import CountMessageTokensRequest -from .types.discuss_service import CountMessageTokensResponse -from .types.discuss_service import Example -from .types.discuss_service import GenerateMessageRequest -from .types.discuss_service import GenerateMessageResponse -from .types.discuss_service import Message -from .types.discuss_service import MessagePrompt -from .types.model import Model -from .types.model_service import CreateTunedModelMetadata -from .types.model_service import CreateTunedModelRequest -from .types.model_service import DeleteTunedModelRequest -from .types.model_service import GetModelRequest -from .types.model_service import GetTunedModelRequest -from .types.model_service import ListModelsRequest -from .types.model_service import ListModelsResponse -from .types.model_service import ListTunedModelsRequest -from .types.model_service import ListTunedModelsResponse -from .types.model_service import UpdateTunedModelRequest -from .types.permission import Permission -from .types.permission_service import CreatePermissionRequest -from .types.permission_service import DeletePermissionRequest -from .types.permission_service import GetPermissionRequest -from .types.permission_service import ListPermissionsRequest -from .types.permission_service import ListPermissionsResponse -from .types.permission_service import TransferOwnershipRequest -from .types.permission_service import TransferOwnershipResponse -from .types.permission_service import UpdatePermissionRequest -from .types.safety import ContentFilter -from .types.safety import SafetyFeedback -from .types.safety import SafetyRating -from .types.safety import SafetySetting -from .types.safety import HarmCategory -from .types.text_service import BatchEmbedTextRequest -from .types.text_service import BatchEmbedTextResponse -from .types.text_service import CountTextTokensRequest -from .types.text_service import CountTextTokensResponse -from .types.text_service import Embedding -from .types.text_service import EmbedTextRequest -from .types.text_service import EmbedTextResponse -from .types.text_service import GenerateTextRequest -from .types.text_service import GenerateTextResponse -from .types.text_service import TextCompletion -from .types.text_service import TextPrompt -from .types.tuned_model import Dataset -from .types.tuned_model import Hyperparameters -from .types.tuned_model import TunedModel -from .types.tuned_model import TunedModelSource -from .types.tuned_model import TuningExample -from .types.tuned_model import TuningExamples -from .types.tuned_model import TuningSnapshot -from .types.tuned_model import TuningTask - -__all__ = ( - 'DiscussServiceAsyncClient', - 'ModelServiceAsyncClient', - 'PermissionServiceAsyncClient', - 'TextServiceAsyncClient', -'BatchEmbedTextRequest', 
-'BatchEmbedTextResponse', -'CitationMetadata', -'CitationSource', -'ContentFilter', -'CountMessageTokensRequest', -'CountMessageTokensResponse', -'CountTextTokensRequest', -'CountTextTokensResponse', -'CreatePermissionRequest', -'CreateTunedModelMetadata', -'CreateTunedModelRequest', -'Dataset', -'DeletePermissionRequest', -'DeleteTunedModelRequest', -'DiscussServiceClient', -'EmbedTextRequest', -'EmbedTextResponse', -'Embedding', -'Example', -'GenerateMessageRequest', -'GenerateMessageResponse', -'GenerateTextRequest', -'GenerateTextResponse', -'GetModelRequest', -'GetPermissionRequest', -'GetTunedModelRequest', -'HarmCategory', -'Hyperparameters', -'ListModelsRequest', -'ListModelsResponse', -'ListPermissionsRequest', -'ListPermissionsResponse', -'ListTunedModelsRequest', -'ListTunedModelsResponse', -'Message', -'MessagePrompt', -'Model', -'ModelServiceClient', -'Permission', -'PermissionServiceClient', -'SafetyFeedback', -'SafetyRating', -'SafetySetting', -'TextCompletion', -'TextPrompt', -'TextServiceClient', -'TransferOwnershipRequest', -'TransferOwnershipResponse', -'TunedModel', -'TunedModelSource', -'TuningExample', -'TuningExamples', -'TuningSnapshot', -'TuningTask', -'UpdatePermissionRequest', -'UpdateTunedModelRequest', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_version.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_version.py deleted file mode 100644 index 360a0d13ebdd..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_version.py +++ /dev/null @@ -1,16 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -__version__ = "0.0.0" # {x-release-please-version} diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/py.typed b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/py.typed deleted file mode 100644 index 38773eee6363..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/py.typed +++ /dev/null @@ -1,2 +0,0 @@ -# Marker file for PEP 561. -# The google-ai-generativelanguage package uses inline types. 
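The staging copy of the v1beta3 package __init__.py deleted above (and re-added under packages/ later in this patch) defines the entire public surface of generativelanguage_v1beta3: four service clients plus their request, response, and resource types, all re-exported at the package root. For orientation, a minimal sketch of how that surface is typically consumed; the model name is a placeholder and the snippet assumes application credentials are already configured, neither of which comes from this patch:

# Illustrative sketch only; exercises the public surface re-exported above.
# Assumes credentials are configured; "models/chat-bison-001" is a placeholder.
from google.ai.generativelanguage_v1beta3 import (
    DiscussServiceClient,
    Message,
    MessagePrompt,
)

client = DiscussServiceClient()  # defaults to the gRPC transport
prompt = MessagePrompt(messages=[Message(content="Hello!")])
response = client.generate_message(model="models/chat-bison-001", prompt=prompt)
for candidate in response.candidates:
    print(candidate.content)

Because everything is re-exported at the package root, callers never need to import from the services or types submodules directly.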
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/__init__.py deleted file mode 100644 index b585c1ce424c..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from collections import OrderedDict -from typing import Dict, Type - -from .base import DiscussServiceTransport -from .grpc import DiscussServiceGrpcTransport -from .grpc_asyncio import DiscussServiceGrpcAsyncIOTransport -from .rest import DiscussServiceRestTransport -from .rest import DiscussServiceRestInterceptor - - -# Compile a registry of transports. -_transport_registry = OrderedDict() # type: Dict[str, Type[DiscussServiceTransport]] -_transport_registry['grpc'] = DiscussServiceGrpcTransport -_transport_registry['grpc_asyncio'] = DiscussServiceGrpcAsyncIOTransport -_transport_registry['rest'] = DiscussServiceRestTransport - -__all__ = ( - 'DiscussServiceTransport', - 'DiscussServiceGrpcTransport', - 'DiscussServiceGrpcAsyncIOTransport', - 'DiscussServiceRestTransport', - 'DiscussServiceRestInterceptor', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/__init__.py deleted file mode 100644 index c51cadf4ba09..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from collections import OrderedDict -from typing import Dict, Type - -from .base import ModelServiceTransport -from .grpc import ModelServiceGrpcTransport -from .grpc_asyncio import ModelServiceGrpcAsyncIOTransport -from .rest import ModelServiceRestTransport -from .rest import ModelServiceRestInterceptor - - -# Compile a registry of transports. 
-_transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] -_transport_registry['grpc'] = ModelServiceGrpcTransport -_transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport -_transport_registry['rest'] = ModelServiceRestTransport - -__all__ = ( - 'ModelServiceTransport', - 'ModelServiceGrpcTransport', - 'ModelServiceGrpcAsyncIOTransport', - 'ModelServiceRestTransport', - 'ModelServiceRestInterceptor', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/__init__.py deleted file mode 100644 index f167a9c3175d..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from .client import TextServiceClient -from .async_client import TextServiceAsyncClient - -__all__ = ( - 'TextServiceClient', - 'TextServiceAsyncClient', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/__init__.py deleted file mode 100644 index 71e949c7a4f5..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -from collections import OrderedDict -from typing import Dict, Type - -from .base import TextServiceTransport -from .grpc import TextServiceGrpcTransport -from .grpc_asyncio import TextServiceGrpcAsyncIOTransport -from .rest import TextServiceRestTransport -from .rest import TextServiceRestInterceptor - - -# Compile a registry of transports. 
-_transport_registry = OrderedDict() # type: Dict[str, Type[TextServiceTransport]] -_transport_registry['grpc'] = TextServiceGrpcTransport -_transport_registry['grpc_asyncio'] = TextServiceGrpcAsyncIOTransport -_transport_registry['rest'] = TextServiceRestTransport - -__all__ = ( - 'TextServiceTransport', - 'TextServiceGrpcTransport', - 'TextServiceGrpcAsyncIOTransport', - 'TextServiceRestTransport', - 'TextServiceRestInterceptor', -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/mypy.ini b/owl-bot-staging/google-ai-generativelanguage/v1beta3/mypy.ini deleted file mode 100644 index 574c5aed394b..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/mypy.ini +++ /dev/null @@ -1,3 +0,0 @@ -[mypy] -python_version = 3.7 -namespace_packages = True diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/noxfile.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/noxfile.py deleted file mode 100644 index 66bac3a254b7..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/noxfile.py +++ /dev/null @@ -1,184 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import os -import pathlib -import shutil -import subprocess -import sys - - -import nox # type: ignore - -ALL_PYTHON = [ - "3.7", - "3.8", - "3.9", - "3.10", - "3.11", -] - -CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute() - -LOWER_BOUND_CONSTRAINTS_FILE = CURRENT_DIRECTORY / "constraints.txt" -PACKAGE_NAME = subprocess.check_output([sys.executable, "setup.py", "--name"], encoding="utf-8") - -BLACK_VERSION = "black==22.3.0" -BLACK_PATHS = ["docs", "google", "tests", "samples", "noxfile.py", "setup.py"] -DEFAULT_PYTHON_VERSION = "3.11" - -nox.sessions = [ - "unit", - "cover", - "mypy", - "check_lower_bounds" - # exclude update_lower_bounds from default - "docs", - "blacken", - "lint", - "lint_setup_py", -] - -@nox.session(python=ALL_PYTHON) -def unit(session): - """Run the unit test suite.""" - - session.install('coverage', 'pytest', 'pytest-cov', 'pytest-asyncio', 'asyncmock; python_version < "3.8"') - session.install('-e', '.') - - session.run( - 'py.test', - '--quiet', - '--cov=google/ai/generativelanguage_v1beta3/', - '--cov=tests/', - '--cov-config=.coveragerc', - '--cov-report=term', - '--cov-report=html', - os.path.join('tests', 'unit', ''.join(session.posargs)) - ) - - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def cover(session): - """Run the final coverage report. - This outputs the coverage report aggregating coverage from the unit - test runs (not system test runs), and then erases coverage data. 
- """ - session.install("coverage", "pytest-cov") - session.run("coverage", "report", "--show-missing", "--fail-under=100") - - session.run("coverage", "erase") - - -@nox.session(python=ALL_PYTHON) -def mypy(session): - """Run the type checker.""" - session.install( - 'mypy', - 'types-requests', - 'types-protobuf' - ) - session.install('.') - session.run( - 'mypy', - '--explicit-package-bases', - 'google', - ) - - -@nox.session -def update_lower_bounds(session): - """Update lower bounds in constraints.txt to match setup.py""" - session.install('google-cloud-testutils') - session.install('.') - - session.run( - 'lower-bound-checker', - 'update', - '--package-name', - PACKAGE_NAME, - '--constraints-file', - str(LOWER_BOUND_CONSTRAINTS_FILE), - ) - - -@nox.session -def check_lower_bounds(session): - """Check lower bounds in setup.py are reflected in constraints file""" - session.install('google-cloud-testutils') - session.install('.') - - session.run( - 'lower-bound-checker', - 'check', - '--package-name', - PACKAGE_NAME, - '--constraints-file', - str(LOWER_BOUND_CONSTRAINTS_FILE), - ) - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def docs(session): - """Build the docs for this library.""" - - session.install("-e", ".") - session.install("sphinx==7.0.1", "alabaster", "recommonmark") - - shutil.rmtree(os.path.join("docs", "_build"), ignore_errors=True) - session.run( - "sphinx-build", - "-W", # warnings as errors - "-T", # show full traceback on exception - "-N", # no colors - "-b", - "html", - "-d", - os.path.join("docs", "_build", "doctrees", ""), - os.path.join("docs", ""), - os.path.join("docs", "_build", "html", ""), - ) - - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def lint(session): - """Run linters. - - Returns a failure if the linters find linting errors or sufficiently - serious code quality issues. - """ - session.install("flake8", BLACK_VERSION) - session.run( - "black", - "--check", - *BLACK_PATHS, - ) - session.run("flake8", "google", "tests", "samples") - - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def blacken(session): - """Run black. Format code to uniform standard.""" - session.install(BLACK_VERSION) - session.run( - "black", - *BLACK_PATHS, - ) - - -@nox.session(python=DEFAULT_PYTHON_VERSION) -def lint_setup_py(session): - """Verify that setup.py is valid (including RST check).""" - session.install("docutils", "pygments") - session.run("python", "setup.py", "check", "--restructuredtext", "--strict") diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/setup.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/setup.py deleted file mode 100644 index 0e0b1e55d45f..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/setup.py +++ /dev/null @@ -1,90 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import io -import os - -import setuptools # type: ignore - -package_root = os.path.abspath(os.path.dirname(__file__)) - -name = 'google-ai-generativelanguage' - - -description = "Google Ai Generativelanguage API client library" - -version = {} -with open(os.path.join(package_root, 'google/ai/generativelanguage/gapic_version.py')) as fp: - exec(fp.read(), version) -version = version["__version__"] - -if version[0] == "0": - release_status = "Development Status :: 4 - Beta" -else: - release_status = "Development Status :: 5 - Production/Stable" - -dependencies = [ - "google-api-core[grpc] >= 1.34.0, <3.0.0dev,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*,!=2.10.*", - "proto-plus >= 1.22.0, <2.0.0dev", - "proto-plus >= 1.22.2, <2.0.0dev; python_version>='3.11'", - "protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", -] -url = "https://github.com/googleapis/python-ai-generativelanguage" - -package_root = os.path.abspath(os.path.dirname(__file__)) - -readme_filename = os.path.join(package_root, "README.rst") -with io.open(readme_filename, encoding="utf-8") as readme_file: - readme = readme_file.read() - -packages = [ - package - for package in setuptools.PEP420PackageFinder.find() - if package.startswith("google") -] - -namespaces = ["google", "google.ai"] - -setuptools.setup( - name=name, - version=version, - description=description, - long_description=readme, - author="Google LLC", - author_email="googleapis-packages@google.com", - license="Apache 2.0", - url=url, - classifiers=[ - release_status, - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Operating System :: OS Independent", - "Topic :: Internet", - ], - platforms="Posix; MacOS X; Windows", - packages=packages, - python_requires=">=3.7", - namespace_packages=namespaces, - install_requires=dependencies, - include_package_data=True, - zip_safe=False, -) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.10.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.10.txt deleted file mode 100644 index ed7f9aed2559..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.10.txt +++ /dev/null @@ -1,6 +0,0 @@ -# -*- coding: utf-8 -*- -# This constraints file is required for unit tests. -# List all library dependencies and extras in this file. -google-api-core -proto-plus -protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.11.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.11.txt deleted file mode 100644 index ed7f9aed2559..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.11.txt +++ /dev/null @@ -1,6 +0,0 @@ -# -*- coding: utf-8 -*- -# This constraints file is required for unit tests. -# List all library dependencies and extras in this file. 
-google-api-core -proto-plus -protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.12.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.12.txt deleted file mode 100644 index ed7f9aed2559..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.12.txt +++ /dev/null @@ -1,6 +0,0 @@ -# -*- coding: utf-8 -*- -# This constraints file is required for unit tests. -# List all library dependencies and extras in this file. -google-api-core -proto-plus -protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.7.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.7.txt deleted file mode 100644 index 6c44adfea7ee..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.7.txt +++ /dev/null @@ -1,9 +0,0 @@ -# This constraints file is used to check that lower bounds -# are correct in setup.py -# List all library dependencies and extras in this file. -# Pin the version to the lower bound. -# e.g., if setup.py has "google-cloud-foo >= 1.14.0, < 2.0.0dev", -# Then this file should have google-cloud-foo==1.14.0 -google-api-core==1.34.0 -proto-plus==1.22.0 -protobuf==3.19.5 diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.8.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.8.txt deleted file mode 100644 index ed7f9aed2559..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.8.txt +++ /dev/null @@ -1,6 +0,0 @@ -# -*- coding: utf-8 -*- -# This constraints file is required for unit tests. -# List all library dependencies and extras in this file. -google-api-core -proto-plus -protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.9.txt b/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.9.txt deleted file mode 100644 index ed7f9aed2559..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/testing/constraints-3.9.txt +++ /dev/null @@ -1,6 +0,0 @@ -# -*- coding: utf-8 -*- -# This constraints file is required for unit tests. -# List all library dependencies and extras in this file. -google-api-core -proto-plus -protobuf diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/__init__.py deleted file mode 100644 index 1b4db446eb8d..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ - -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/__init__.py deleted file mode 100644 index 1b4db446eb8d..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ - -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/__init__.py deleted file mode 100644 index 1b4db446eb8d..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ - -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/__init__.py b/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/__init__.py deleted file mode 100644 index 1b4db446eb8d..000000000000 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ - -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# diff --git a/packages/google-ai-generativelanguage/.flake8 b/packages/google-ai-generativelanguage/.flake8 index 2e438749863d..87f6e408c47d 100644 --- a/packages/google-ai-generativelanguage/.flake8 +++ b/packages/google-ai-generativelanguage/.flake8 @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# Copyright 2020 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
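The deleted constraints-3.7.txt above records the convention spelled out in its header: every ">=" lower bound declared in setup.py is pinned to an exact "==" version for the oldest supported Python, so unit tests exercise the true floor. A hypothetical sketch of that consistency check; the real enforcement is the lower-bound-checker from google-cloud-testutils invoked in noxfile.py, not this code:

# Hypothetical sketch, not the real lower-bound-checker: confirm that each
# ">=" floor from setup.py appears as an exact "==" pin for Python 3.7.
import re

setup_deps = [
    "google-api-core[grpc] >= 1.34.0",
    "proto-plus >= 1.22.0",
    "protobuf>=3.19.5",
]
pins = {"google-api-core": "1.34.0", "proto-plus": "1.22.0", "protobuf": "3.19.5"}

for dep in setup_deps:
    m = re.match(r"^([A-Za-z0-9_.-]+)(?:\[[^\]]*\])?\s*>=\s*([0-9][\w.]*)", dep)
    assert m is not None, f"could not parse: {dep}"
    name, floor = m.groups()
    assert pins.get(name) == floor, f"{name}: pinned {pins.get(name)}, floor {floor}"
print("lower bounds match")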
diff --git a/packages/google-ai-generativelanguage/CONTRIBUTING.rst b/packages/google-ai-generativelanguage/CONTRIBUTING.rst index 8bbcfb7cc7eb..0dd4cd7df3d3 100644 --- a/packages/google-ai-generativelanguage/CONTRIBUTING.rst +++ b/packages/google-ai-generativelanguage/CONTRIBUTING.rst @@ -236,7 +236,7 @@ We support: Supported versions can be found in our ``noxfile.py`` `config`_. -.. _config: https://github.com/googleapis/google-cloud-python/blob/main/noxfile.py +.. _config: https://github.com/googleapis/google-cloud-python/blob/main/packages/google-ai-generativelanguage/noxfile.py ********** diff --git a/packages/google-ai-generativelanguage/MANIFEST.in b/packages/google-ai-generativelanguage/MANIFEST.in index e783f4c6209b..e0a66705318e 100644 --- a/packages/google-ai-generativelanguage/MANIFEST.in +++ b/packages/google-ai-generativelanguage/MANIFEST.in @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# Copyright 2020 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/packages/google-ai-generativelanguage/README.rst b/packages/google-ai-generativelanguage/README.rst index 94ee778b3480..dbad3c6a8c25 100644 --- a/packages/google-ai-generativelanguage/README.rst +++ b/packages/google-ai-generativelanguage/README.rst @@ -36,21 +36,24 @@ In order to use this library, you first need to go through the following steps: Installation ~~~~~~~~~~~~ -Install this library in a `virtualenv`_ using pip. `virtualenv`_ is a tool to -create isolated Python environments. The basic problem it addresses is one of -dependencies and versions, and indirectly permissions. +Install this library in a virtual environment using `venv`_. `venv`_ is a tool that +creates isolated Python environments. These isolated environments can have separate +versions of Python packages, which allows you to isolate one project's dependencies +from the dependencies of other projects. -With `virtualenv`_, it's possible to install this library without needing system +With `venv`_, it's possible to install this library without needing system install permissions, and without clashing with the installed system dependencies. -.. _`virtualenv`: https://virtualenv.pypa.io/en/latest/ +.. _`venv`: https://docs.python.org/3/library/venv.html Code samples and snippets ~~~~~~~~~~~~~~~~~~~~~~~~~ -Code samples and snippets live in the `samples/` folder. +Code samples and snippets live in the `samples/`_ folder. + +.. _samples/: https://github.com/googleapis/google-cloud-python/tree/main/packages/google-ai-generativelanguage/samples Supported Python Versions @@ -77,10 +80,9 @@ Mac/Linux .. code-block:: console - pip install virtualenv - virtualenv <your-env> + python3 -m venv <your-env> source <your-env>/bin/activate - <your-env>/bin/pip install google-ai-generativelanguage + pip install google-ai-generativelanguage Windows ^^^^^^^ .. 
code-block:: console - pip install virtualenv - virtualenv - <your-env>\Scripts\activate - <your-env>\Scripts\pip.exe install google-ai-generativelanguage + py -m venv <your-env> + .\<your-env>\Scripts\activate + pip install google-ai-generativelanguage Next Steps ~~~~~~~~~~ diff --git a/packages/google-ai-generativelanguage/docs/conf.py b/packages/google-ai-generativelanguage/docs/conf.py index c865aa35cd19..a7f886e7207c 100644 --- a/packages/google-ai-generativelanguage/docs/conf.py +++ b/packages/google-ai-generativelanguage/docs/conf.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2021 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/discuss_service.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1beta3/discuss_service.rst similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/discuss_service.rst rename to packages/google-ai-generativelanguage/docs/generativelanguage_v1beta3/discuss_service.rst diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/model_service.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1beta3/model_service.rst similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/model_service.rst rename to packages/google-ai-generativelanguage/docs/generativelanguage_v1beta3/model_service.rst diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/permission_service.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1beta3/permission_service.rst similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/permission_service.rst rename to packages/google-ai-generativelanguage/docs/generativelanguage_v1beta3/permission_service.rst diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/services.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1beta3/services.rst similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/services.rst rename to packages/google-ai-generativelanguage/docs/generativelanguage_v1beta3/services.rst diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/text_service.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1beta3/text_service.rst similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/text_service.rst rename to packages/google-ai-generativelanguage/docs/generativelanguage_v1beta3/text_service.rst diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/types.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1beta3/types.rst similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/docs/generativelanguage_v1beta3/types.rst rename to packages/google-ai-generativelanguage/docs/generativelanguage_v1beta3/types.rst diff --git a/packages/google-ai-generativelanguage/docs/index.rst b/packages/google-ai-generativelanguage/docs/index.rst index 8b2a9dd36744..4c7f05ce43b3 100644 ---
a/packages/google-ai-generativelanguage/docs/index.rst +++ b/packages/google-ai-generativelanguage/docs/index.rst @@ -2,6 +2,9 @@ .. include:: multiprocessing.rst +This package includes clients for multiple versions of Generative Language API. +By default, you will get version ``generativelanguage_v1beta2``. + API Reference ------------- @@ -11,6 +14,14 @@ API Reference generativelanguage_v1beta2/services generativelanguage_v1beta2/types +API Reference +------------- +.. toctree:: + :maxdepth: 2 + + generativelanguage_v1beta3/services + generativelanguage_v1beta3/types + Changelog --------- diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/services/discuss_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/services/discuss_service/async_client.py index 89021a05f793..076ca0259c32 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/services/discuss_service/async_client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/services/discuss_service/async_client.py @@ -341,6 +341,7 @@ async def sample_generate_message(): Returns: google.ai.generativelanguage_v1beta2.types.GenerateMessageResponse: The response from the model. + This includes candidate messages and conversation history in the form of chronologically-ordered messages. diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/services/discuss_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/services/discuss_service/client.py index 6715c8a7827d..8d529edeedb1 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/services/discuss_service/client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/services/discuss_service/client.py @@ -563,6 +563,7 @@ def sample_generate_message(): Returns: google.ai.generativelanguage_v1beta2.types.GenerateMessageResponse: The response from the model. + This includes candidate messages and conversation history in the form of chronologically-ordered messages. diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/rest.py index 0165c8137ff3..00bb295d38cd 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/rest.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/rest.py @@ -373,6 +373,7 @@ def __call__( Returns: ~.discuss_service.GenerateMessageResponse: The response from the model. + This includes candidate messages and conversation history in the form of chronologically-ordered messages. diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/discuss_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/discuss_service.py index f926726bc772..e4df21d77f91 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/discuss_service.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/discuss_service.py @@ -123,6 +123,7 @@ class GenerateMessageRequest(proto.Message): class GenerateMessageResponse(proto.Message): r"""The response from the model. 
+ This includes candidate messages and conversation history in the form of chronologically-ordered messages. @@ -174,9 +175,11 @@ class Message(proto.Message): Attributes: author (str): Optional. The author of this Message. + This serves as a key for tagging the content of this Message when it is fed to the model as text. + The author can be any alphanumeric string. content (str): Required. The text content of the structured ``Message``. @@ -276,6 +279,7 @@ class MessagePrompt(proto.Message): class Example(proto.Message): r"""An input/output example used to instruct the Model. + It demonstrates how the model should respond or format its response. diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/model.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/model.py index da5ac7fa0fa3..2cd866f80f0d 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/model.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/model.py @@ -53,6 +53,7 @@ class Model(proto.Message): - ``chat-bison`` version (str): Required. The version number of the model. + This represents the major version display_name (str): The human-readable name of the model. E.g. diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/safety.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/safety.py index 099bd5cfbb20..5584347b1886 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/safety.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/safety.py @@ -33,6 +33,7 @@ class HarmCategory(proto.Enum): r"""The category of a rating. + These categories cover various kinds of harms that developers may wish to adjust. @@ -117,6 +118,7 @@ class BlockedReason(proto.Enum): class SafetyFeedback(proto.Message): r"""Safety feedback for an entire request. + This field is populated if content in the input and/or response is blocked due to safety settings. SafetyFeedback may not exist for every HarmCategory. Each SafetyFeedback will return the @@ -145,6 +147,7 @@ class SafetyFeedback(proto.Message): class SafetyRating(proto.Message): r"""Safety rating for a piece of content. + The safety rating contains the category of harm and the harm probability level in that category for a piece of content. Content is classified for safety across a number of harm @@ -161,6 +164,7 @@ class SafetyRating(proto.Message): class HarmProbability(proto.Enum): r"""The probability that a piece of content is harmful. + The classification system gives the probability of the content being unsafe. This does not indicate the severity of harm for a piece of content. @@ -198,6 +202,7 @@ class HarmProbability(proto.Enum): class SafetySetting(proto.Message): r"""Safety setting, affecting the safety-blocking behavior. + Passing a safety setting for a category changes the allowed proability that content is blocked. 
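The v1beta2 safety hunks above tie together HarmCategory, SafetyRating with its HarmProbability, and SafetySetting: one setting per category shifts the probability threshold at which content is blocked. A hedged sketch of wiring these types into a text request; the enum members and model name are assumptions about this API surface rather than anything stated in the patch:

# Illustrative only; assumes these HarmCategory / HarmBlockThreshold members
# exist in v1beta2 and uses a placeholder model name.
from google.ai import generativelanguage_v1beta2 as glm

setting = glm.SafetySetting(
    category=glm.HarmCategory.HARM_CATEGORY_TOXICITY,
    threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
)
request = glm.GenerateTextRequest(
    model="models/text-bison-001",
    prompt=glm.TextPrompt(text="Write a friendly greeting."),
    safety_settings=[setting],
)
response = glm.TextServiceClient().generate_text(request=request)
for rating in response.candidates[0].safety_ratings:
    # At most one rating per category, per the TextCompletion docs below.
    print(rating.category, rating.probability)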
diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/text_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/text_service.py index 5227db6bd5f4..4e1797ebc28a 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/text_service.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/types/text_service.py @@ -216,6 +216,7 @@ class GenerateTextResponse(proto.Message): class TextPrompt(proto.Message): r"""Text given to the model as a prompt. + The Model will use this TextPrompt to Generate a text completion. @@ -241,6 +242,7 @@ class TextCompletion(proto.Message): the model. safety_ratings (MutableSequence[google.ai.generativelanguage_v1beta2.types.SafetyRating]): Ratings for the safety of a response. + There is at most one rating per category. citation_metadata (google.ai.generativelanguage_v1beta2.types.CitationMetadata): Output only. Citation information for model-generated diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/__init__.py new file mode 100644 index 000000000000..f8746750bee7 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/__init__.py @@ -0,0 +1,155 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version + +__version__ = package_version.__version__ + + +from .services.discuss_service import DiscussServiceAsyncClient, DiscussServiceClient +from .services.model_service import ModelServiceAsyncClient, ModelServiceClient +from .services.permission_service import ( + PermissionServiceAsyncClient, + PermissionServiceClient, +) +from .services.text_service import TextServiceAsyncClient, TextServiceClient +from .types.citation import CitationMetadata, CitationSource +from .types.discuss_service import ( + CountMessageTokensRequest, + CountMessageTokensResponse, + Example, + GenerateMessageRequest, + GenerateMessageResponse, + Message, + MessagePrompt, +) +from .types.model import Model +from .types.model_service import ( + CreateTunedModelMetadata, + CreateTunedModelRequest, + DeleteTunedModelRequest, + GetModelRequest, + GetTunedModelRequest, + ListModelsRequest, + ListModelsResponse, + ListTunedModelsRequest, + ListTunedModelsResponse, + UpdateTunedModelRequest, +) +from .types.permission import Permission +from .types.permission_service import ( + CreatePermissionRequest, + DeletePermissionRequest, + GetPermissionRequest, + ListPermissionsRequest, + ListPermissionsResponse, + TransferOwnershipRequest, + TransferOwnershipResponse, + UpdatePermissionRequest, +) +from .types.safety import ( + ContentFilter, + HarmCategory, + SafetyFeedback, + SafetyRating, + SafetySetting, +) +from .types.text_service import ( + BatchEmbedTextRequest, + BatchEmbedTextResponse, + CountTextTokensRequest, + CountTextTokensResponse, + Embedding, + EmbedTextRequest, + EmbedTextResponse, + GenerateTextRequest, + GenerateTextResponse, + TextCompletion, + TextPrompt, +) +from .types.tuned_model import ( + Dataset, + Hyperparameters, + TunedModel, + TunedModelSource, + TuningExample, + TuningExamples, + TuningSnapshot, + TuningTask, +) + +__all__ = ( + "DiscussServiceAsyncClient", + "ModelServiceAsyncClient", + "PermissionServiceAsyncClient", + "TextServiceAsyncClient", + "BatchEmbedTextRequest", + "BatchEmbedTextResponse", + "CitationMetadata", + "CitationSource", + "ContentFilter", + "CountMessageTokensRequest", + "CountMessageTokensResponse", + "CountTextTokensRequest", + "CountTextTokensResponse", + "CreatePermissionRequest", + "CreateTunedModelMetadata", + "CreateTunedModelRequest", + "Dataset", + "DeletePermissionRequest", + "DeleteTunedModelRequest", + "DiscussServiceClient", + "EmbedTextRequest", + "EmbedTextResponse", + "Embedding", + "Example", + "GenerateMessageRequest", + "GenerateMessageResponse", + "GenerateTextRequest", + "GenerateTextResponse", + "GetModelRequest", + "GetPermissionRequest", + "GetTunedModelRequest", + "HarmCategory", + "Hyperparameters", + "ListModelsRequest", + "ListModelsResponse", + "ListPermissionsRequest", + "ListPermissionsResponse", + "ListTunedModelsRequest", + "ListTunedModelsResponse", + "Message", + "MessagePrompt", + "Model", + "ModelServiceClient", + "Permission", + "PermissionServiceClient", + "SafetyFeedback", + "SafetyRating", + "SafetySetting", + "TextCompletion", + "TextPrompt", + "TextServiceClient", + "TransferOwnershipRequest", + "TransferOwnershipResponse", + "TunedModel", + "TunedModelSource", + "TuningExample", + "TuningExamples", + "TuningSnapshot", + "TuningTask", + "UpdatePermissionRequest", + "UpdateTunedModelRequest", +) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_metadata.json 
b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/gapic_metadata.json similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/gapic_metadata.json rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/gapic_metadata.json diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/gapic_version.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/gapic_version.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/gapic_version.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/gapic_version.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/py.typed b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/py.typed similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage/py.typed rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/py.typed diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/__init__.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/__init__.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/__init__.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/__init__.py similarity index 92% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/__init__.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/__init__.py index c5c6e8208269..2247026798d5 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/__init__.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/__init__.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from .client import DiscussServiceClient from .async_client import DiscussServiceAsyncClient +from .client import DiscussServiceClient __all__ = ( - 'DiscussServiceClient', - 'DiscussServiceAsyncClient', + "DiscussServiceClient", + "DiscussServiceAsyncClient", ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/async_client.py similarity index 83% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/async_client.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/async_client.py index 1f9cde10aa21..e2b522b2d84e 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/async_client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/async_client.py @@ -16,28 +16,39 @@ from collections import OrderedDict import functools import re -from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union - -from google.ai.generativelanguage_v1beta3 import gapic_version as package_version +from typing import ( + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) -from google.api_core.client_options import ClientOptions from google.api_core import exceptions as core_exceptions from google.api_core import gapic_v1 from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core.client_options import ClientOptions +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version try: OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] except AttributeError: # pragma: NO COVER OptionalRetry = Union[retries.Retry, object] # type: ignore -from google.ai.generativelanguage_v1beta3.types import discuss_service -from google.ai.generativelanguage_v1beta3.types import safety -from google.longrunning import operations_pb2 # type: ignore -from .transports.base import DiscussServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import DiscussServiceGrpcAsyncIOTransport +from google.longrunning import operations_pb2 # type: ignore + +from google.ai.generativelanguage_v1beta3.types import discuss_service, safety + from .client import DiscussServiceClient +from .transports.base import DEFAULT_CLIENT_INFO, DiscussServiceTransport +from .transports.grpc_asyncio import DiscussServiceGrpcAsyncIOTransport class DiscussServiceAsyncClient: @@ -54,16 +65,30 @@ class DiscussServiceAsyncClient: model_path = staticmethod(DiscussServiceClient.model_path) parse_model_path = staticmethod(DiscussServiceClient.parse_model_path) - common_billing_account_path = staticmethod(DiscussServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(DiscussServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + DiscussServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + 
DiscussServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(DiscussServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(DiscussServiceClient.parse_common_folder_path) - common_organization_path = staticmethod(DiscussServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(DiscussServiceClient.parse_common_organization_path) + parse_common_folder_path = staticmethod( + DiscussServiceClient.parse_common_folder_path + ) + common_organization_path = staticmethod( + DiscussServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + DiscussServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(DiscussServiceClient.common_project_path) - parse_common_project_path = staticmethod(DiscussServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + DiscussServiceClient.parse_common_project_path + ) common_location_path = staticmethod(DiscussServiceClient.common_location_path) - parse_common_location_path = staticmethod(DiscussServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + DiscussServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -99,7 +124,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file @classmethod - def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None): + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): """Return the API endpoint and client cert source for mutual TLS. The client cert source is determined in the following order: @@ -141,14 +168,18 @@ def transport(self) -> DiscussServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(DiscussServiceClient).get_transport_class, type(DiscussServiceClient)) - - def __init__(self, *, - credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, DiscussServiceTransport] = "grpc_asyncio", - client_options: Optional[ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + get_transport_class = functools.partial( + type(DiscussServiceClient).get_transport_class, type(DiscussServiceClient) + ) + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Union[str, DiscussServiceTransport] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiates the discuss service client. 
Args: @@ -186,22 +217,22 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def generate_message(self, - request: Optional[Union[discuss_service.GenerateMessageRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[discuss_service.MessagePrompt] = None, - temperature: Optional[float] = None, - candidate_count: Optional[int] = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> discuss_service.GenerateMessageResponse: + async def generate_message( + self, + request: Optional[Union[discuss_service.GenerateMessageRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> discuss_service.GenerateMessageResponse: r"""Generates a response from the model given an input ``MessagePrompt``. @@ -321,10 +352,14 @@ async def sample_generate_message(): # Create or coerce a protobuf request object. # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([model, prompt, temperature, candidate_count, top_p, top_k]) + has_flattened_params = any( + [model, prompt, temperature, candidate_count, top_p, top_k] + ) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = discuss_service.GenerateMessageRequest(request) @@ -354,9 +389,7 @@ async def sample_generate_message(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), ) # Send the request. @@ -370,15 +403,18 @@ async def sample_generate_message(): # Done; return the response. return response - async def count_message_tokens(self, - request: Optional[Union[discuss_service.CountMessageTokensRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[discuss_service.MessagePrompt] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> discuss_service.CountMessageTokensResponse: + async def count_message_tokens( + self, + request: Optional[ + Union[discuss_service.CountMessageTokensRequest, dict] + ] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> discuss_service.CountMessageTokensResponse: r"""Runs a model's tokenizer on a string and returns the token count. 
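Both hunks above reformat the guard that rejects mixing a request object with flattened keyword fields. A short sketch of the two calling styles the guard permits; the model name is a placeholder and configured credentials are assumed:

# Sketch of the either/or convention enforced by has_flattened_params above.
import asyncio

from google.ai.generativelanguage_v1beta3 import (
    CountMessageTokensRequest,
    DiscussServiceAsyncClient,
    Message,
    MessagePrompt,
)

async def main() -> None:
    client = DiscussServiceAsyncClient()
    prompt = MessagePrompt(messages=[Message(content="Hi")])

    # Style 1: flattened fields.
    resp = await client.count_message_tokens(
        model="models/chat-bison-001", prompt=prompt
    )

    # Style 2: a request object. Passing both styles at once raises ValueError.
    request = CountMessageTokensRequest(model="models/chat-bison-001", prompt=prompt)
    resp = await client.count_message_tokens(request=request)
    print(resp.token_count)

asyncio.run(main())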
@@ -456,8 +492,10 @@ async def sample_count_message_tokens(): # gotten any keyword arguments that map to the request. has_flattened_params = any([model, prompt]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = discuss_service.CountMessageTokensRequest(request) @@ -479,9 +517,7 @@ async def sample_count_message_tokens(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), ) # Send the request. @@ -501,9 +537,10 @@ async def __aenter__(self) -> "DiscussServiceAsyncClient": async def __aexit__(self, exc_type, exc, tb): await self.transport.close() -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - -__all__ = ( - "DiscussServiceAsyncClient", +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ ) + + +__all__ = ("DiscussServiceAsyncClient",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/client.py similarity index 82% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/client.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/client.py index 1e3b5952d0bf..8f22179ad92c 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/client.py @@ -16,29 +16,41 @@ from collections import OrderedDict import os import re -from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast - -from google.ai.generativelanguage_v1beta3 import gapic_version as package_version +from typing import ( + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) from google.api_core import client_options as client_options_lib from google.api_core import exceptions as core_exceptions from google.api_core import gapic_v1 from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version try: OptionalRetry = Union[retries.Retry, 
gapic_v1.method._MethodDefault] except AttributeError: # pragma: NO COVER OptionalRetry = Union[retries.Retry, object] # type: ignore -from google.ai.generativelanguage_v1beta3.types import discuss_service -from google.ai.generativelanguage_v1beta3.types import safety -from google.longrunning import operations_pb2 # type: ignore -from .transports.base import DiscussServiceTransport, DEFAULT_CLIENT_INFO +from google.longrunning import operations_pb2 # type: ignore + +from google.ai.generativelanguage_v1beta3.types import discuss_service, safety + +from .transports.base import DEFAULT_CLIENT_INFO, DiscussServiceTransport from .transports.grpc import DiscussServiceGrpcTransport from .transports.grpc_asyncio import DiscussServiceGrpcAsyncIOTransport from .transports.rest import DiscussServiceRestTransport @@ -51,14 +63,18 @@ class DiscussServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[DiscussServiceTransport]] + + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[DiscussServiceTransport]] _transport_registry["grpc"] = DiscussServiceGrpcTransport _transport_registry["grpc_asyncio"] = DiscussServiceGrpcAsyncIOTransport _transport_registry["rest"] = DiscussServiceRestTransport - def get_transport_class(cls, - label: Optional[str] = None, - ) -> Type[DiscussServiceTransport]: + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[DiscussServiceTransport]: """Returns an appropriate transport class. Args: @@ -150,8 +166,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: DiscussServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) + credentials = service_account.Credentials.from_service_account_file(filename) kwargs["credentials"] = credentials return cls(*args, **kwargs) @@ -168,73 +183,101 @@ def transport(self) -> DiscussServiceTransport: return self._transport @staticmethod - def model_path(model: str,) -> str: + def model_path( + model: str, + ) -> str: """Returns a fully-qualified model string.""" - return "models/{model}".format(model=model, ) + return "models/{model}".format( + model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parses a model path into its component segments.""" m = re.match(r"^models/(?P<model>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path( + billing_account: str, + ) -> str: """Returns a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path( + folder: str, + ) -> str: """Returns a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format( + folder=folder, + ) @staticmethod - def parse_common_folder_path(path: str) -> 
Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P<folder>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path( + organization: str, + ) -> str: """Returns a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format( + organization=organization, + ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P<organization>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path( + project: str, + ) -> str: """Returns a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format( + project=project, + ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P<project>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path( + project: str, + location: str, + ) -> str: """Returns a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path) return m.groupdict() if m else {} @classmethod - def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None): + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): """Return the API endpoint and client cert source for mutual TLS. The client cert source is determined in the following order: @@ -270,9 +313,13 @@ def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_optio use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") if use_client_cert not in ("true", "false"): - raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`") + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) if use_mtls_endpoint not in ("auto", "never", "always"): - raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`") + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) # Figure out the client cert source to use. client_cert_source = None @@ -285,19 +332,23 @@ def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_optio # Figure out which api endpoint to use. 
if client_options.api_endpoint is not None: api_endpoint = client_options.api_endpoint - elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source): + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): api_endpoint = cls.DEFAULT_MTLS_ENDPOINT else: api_endpoint = cls.DEFAULT_ENDPOINT return api_endpoint, client_cert_source - def __init__(self, *, - credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, DiscussServiceTransport]] = None, - client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[Union[str, DiscussServiceTransport]] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiates the discuss service client. Args: @@ -341,11 +392,15 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() client_options = cast(client_options_lib.ClientOptions, client_options) - api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options + ) api_key_value = getattr(client_options, "api_key", None) if api_key_value and credentials: - raise ValueError("client_options.api_key and credentials are mutually exclusive") + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport @@ -353,8 +408,10 @@ def __init__(self, *, if isinstance(transport, DiscussServiceTransport): # transport is a DiscussServiceTransport instance. if credentials or client_options.credentials_file or api_key_value: - raise ValueError("When providing a transport instance, " - "provide its credentials directly.") + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." 
+ ) if client_options.scopes: raise ValueError( "When providing a transport instance, provide its scopes " @@ -364,8 +421,12 @@ def __init__(self, *, else: import google.auth._default # type: ignore - if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): - credentials = google.auth._default.get_api_key_credentials(api_key_value) + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) Transport = type(self).get_transport_class(transport) self._transport = Transport( @@ -380,19 +441,20 @@ def __init__(self, *, api_audience=client_options.api_audience, ) - def generate_message(self, - request: Optional[Union[discuss_service.GenerateMessageRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[discuss_service.MessagePrompt] = None, - temperature: Optional[float] = None, - candidate_count: Optional[int] = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> discuss_service.GenerateMessageResponse: + def generate_message( + self, + request: Optional[Union[discuss_service.GenerateMessageRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> discuss_service.GenerateMessageResponse: r"""Generates a response from the model given an input ``MessagePrompt``. @@ -512,10 +574,14 @@ def sample_generate_message(): # Create or coerce a protobuf request object. # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([model, prompt, temperature, candidate_count, top_p, top_k]) + has_flattened_params = any( + [model, prompt, temperature, candidate_count, top_p, top_k] + ) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a discuss_service.GenerateMessageRequest. @@ -542,12 +608,10 @@ def sample_generate_message(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.generate_message] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), ) # Send the request. @@ -561,15 +625,18 @@ def sample_generate_message(): # Done; return the response. 
return response - def count_message_tokens(self, - request: Optional[Union[discuss_service.CountMessageTokensRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[discuss_service.MessagePrompt] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> discuss_service.CountMessageTokensResponse: + def count_message_tokens( + self, + request: Optional[ + Union[discuss_service.CountMessageTokensRequest, dict] + ] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> discuss_service.CountMessageTokensResponse: r"""Runs a model's tokenizer on a string and returns the token count. @@ -647,8 +714,10 @@ def sample_count_message_tokens(): # gotten any keyword arguments that map to the request. has_flattened_params = any([model, prompt]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a discuss_service.CountMessageTokensRequest. @@ -667,12 +736,10 @@ def sample_count_message_tokens(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.count_message_tokens] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), ) # Send the request. 
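Editorial note: the routing_header idiom shown above is what injects the x-goog-request-params header the backend uses for request routing. A small sketch of the observable behavior, assuming google-api-core is installed; the model name is hypothetical.

    # Sketch (editorial): to_grpc_metadata url-encodes the (key, value) pairs.
    from google.api_core import gapic_v1

    md = gapic_v1.routing_header.to_grpc_metadata((("model", "models/chat-bison-001"),))
    # Expect something like: ('x-goog-request-params', 'model=models%2Fchat-bison-001')
    print(md)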
@@ -700,18 +767,9 @@ def __exit__(self, type, value, traceback): self.transport.close() +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) - - - - - - - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - - -__all__ = ( - "DiscussServiceClient", -) +__all__ = ("DiscussServiceClient",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/__init__.py similarity index 67% rename from owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/__init__.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/__init__.py index b585c1ce424c..209ce4db6d6e 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/discuss_service/transports/__init__.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/__init__.py @@ -19,20 +19,18 @@ from .base import DiscussServiceTransport from .grpc import DiscussServiceGrpcTransport from .grpc_asyncio import DiscussServiceGrpcAsyncIOTransport -from .rest import DiscussServiceRestTransport -from .rest import DiscussServiceRestInterceptor - +from .rest import DiscussServiceRestInterceptor, DiscussServiceRestTransport # Compile a registry of transports. _transport_registry = OrderedDict() # type: Dict[str, Type[DiscussServiceTransport]] -_transport_registry['grpc'] = DiscussServiceGrpcTransport -_transport_registry['grpc_asyncio'] = DiscussServiceGrpcAsyncIOTransport -_transport_registry['rest'] = DiscussServiceRestTransport +_transport_registry["grpc"] = DiscussServiceGrpcTransport +_transport_registry["grpc_asyncio"] = DiscussServiceGrpcAsyncIOTransport +_transport_registry["rest"] = DiscussServiceRestTransport __all__ = ( - 'DiscussServiceTransport', - 'DiscussServiceGrpcTransport', - 'DiscussServiceGrpcAsyncIOTransport', - 'DiscussServiceRestTransport', - 'DiscussServiceRestInterceptor', + "DiscussServiceTransport", + "DiscussServiceGrpcTransport", + "DiscussServiceGrpcAsyncIOTransport", + "DiscussServiceRestTransport", + "DiscussServiceRestInterceptor", ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/base.py similarity index 65% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/base.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/base.py index 7c455e9f245e..c2847b9722ae 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/base.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/base.py @@ -16,41 +16,43 @@ import abc from typing import Awaitable, Callable, Dict, Optional, Sequence, Union -from google.ai.generativelanguage_v1beta3 import gapic_version as package_version - -import 
google.auth # type: ignore import google.api_core from google.api_core import exceptions as core_exceptions from google.api_core import gapic_v1 from google.api_core import retry as retries +import google.auth # type: ignore from google.auth import credentials as ga_credentials # type: ignore -from google.oauth2 import service_account # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account # type: ignore +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version from google.ai.generativelanguage_v1beta3.types import discuss_service -from google.longrunning import operations_pb2 # type: ignore -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) class DiscussServiceTransport(abc.ABC): """Abstract transport class for DiscussService.""" - AUTH_SCOPES = ( - ) + AUTH_SCOPES = () + + DEFAULT_HOST: str = "generativelanguage.googleapis.com" - DEFAULT_HOST: str = 'generativelanguage.googleapis.com' def __init__( - self, *, - host: str = DEFAULT_HOST, - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - **kwargs, - ) -> None: + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -84,30 +86,38 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = google.auth.load_credentials_from_file( - credentials_file, - **scopes_kwargs, - quota_project_id=quota_project_id - ) + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id) + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) # Don't apply audience if the credentials file passed from user. if hasattr(credentials, "with_gdch_audience"): - credentials = credentials.with_gdch_audience(api_audience if api_audience else host) + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) # If the credentials are service account credentials, then always try to use self signed JWT. 
- if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"): + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): credentials = credentials.with_always_use_jwt_access(True) # Save the credentials. self._credentials = credentials # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host def _prep_wrapped_messages(self, client_info): @@ -123,33 +133,39 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } + } def close(self): """Closes resources associated with the transport. - .. warning:: - Only call this method if the transport is NOT shared - with other clients - this may cause errors in other clients! + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! """ raise NotImplementedError() @property - def generate_message(self) -> Callable[ - [discuss_service.GenerateMessageRequest], - Union[ - discuss_service.GenerateMessageResponse, - Awaitable[discuss_service.GenerateMessageResponse] - ]]: + def generate_message( + self, + ) -> Callable[ + [discuss_service.GenerateMessageRequest], + Union[ + discuss_service.GenerateMessageResponse, + Awaitable[discuss_service.GenerateMessageResponse], + ], + ]: raise NotImplementedError() @property - def count_message_tokens(self) -> Callable[ - [discuss_service.CountMessageTokensRequest], - Union[ - discuss_service.CountMessageTokensResponse, - Awaitable[discuss_service.CountMessageTokensResponse] - ]]: + def count_message_tokens( + self, + ) -> Callable[ + [discuss_service.CountMessageTokensRequest], + Union[ + discuss_service.CountMessageTokensResponse, + Awaitable[discuss_service.CountMessageTokensResponse], + ], + ]: raise NotImplementedError() @property @@ -157,6 +173,4 @@ def kind(self) -> str: raise NotImplementedError() -__all__ = ( - 'DiscussServiceTransport', -) +__all__ = ("DiscussServiceTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc.py similarity index 81% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc.py index 3e6abae06b98..c0032918eac1 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc.py @@ -13,20 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import warnings from typing import Callable, Dict, Optional, Sequence, Tuple, Union +import warnings -from google.api_core import grpc_helpers -from google.api_core import gapic_v1 -import google.auth # type: ignore +from google.api_core import gapic_v1, grpc_helpers +import google.auth # type: ignore from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - +from google.longrunning import operations_pb2 # type: ignore import grpc # type: ignore from google.ai.generativelanguage_v1beta3.types import discuss_service -from google.longrunning import operations_pb2 # type: ignore -from .base import DiscussServiceTransport, DEFAULT_CLIENT_INFO + +from .base import DEFAULT_CLIENT_INFO, DiscussServiceTransport class DiscussServiceGrpcTransport(DiscussServiceTransport): @@ -44,23 +43,26 @@ class DiscussServiceGrpcTransport(DiscussServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[grpc.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: """Instantiate the transport. Args: @@ -179,13 +181,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. 
@@ -220,19 +224,21 @@ def create_channel(cls, default_scopes=cls.AUTH_SCOPES, scopes=scopes, default_host=cls.DEFAULT_HOST, - **kwargs + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property - def generate_message(self) -> Callable[ - [discuss_service.GenerateMessageRequest], - discuss_service.GenerateMessageResponse]: + def generate_message( + self, + ) -> Callable[ + [discuss_service.GenerateMessageRequest], + discuss_service.GenerateMessageResponse, + ]: r"""Return a callable for the generate message method over gRPC. Generates a response from the model given an input @@ -248,18 +254,21 @@ def generate_message(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'generate_message' not in self._stubs: - self._stubs['generate_message'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.DiscussService/GenerateMessage', + if "generate_message" not in self._stubs: + self._stubs["generate_message"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.DiscussService/GenerateMessage", request_serializer=discuss_service.GenerateMessageRequest.serialize, response_deserializer=discuss_service.GenerateMessageResponse.deserialize, ) - return self._stubs['generate_message'] + return self._stubs["generate_message"] @property - def count_message_tokens(self) -> Callable[ - [discuss_service.CountMessageTokensRequest], - discuss_service.CountMessageTokensResponse]: + def count_message_tokens( + self, + ) -> Callable[ + [discuss_service.CountMessageTokensRequest], + discuss_service.CountMessageTokensResponse, + ]: r"""Return a callable for the count message tokens method over gRPC. Runs a model's tokenizer on a string and returns the @@ -275,13 +284,13 @@ def count_message_tokens(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'count_message_tokens' not in self._stubs: - self._stubs['count_message_tokens'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.DiscussService/CountMessageTokens', + if "count_message_tokens" not in self._stubs: + self._stubs["count_message_tokens"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.DiscussService/CountMessageTokens", request_serializer=discuss_service.CountMessageTokensRequest.serialize, response_deserializer=discuss_service.CountMessageTokensResponse.deserialize, ) - return self._stubs['count_message_tokens'] + return self._stubs["count_message_tokens"] def close(self): self.grpc_channel.close() @@ -291,6 +300,4 @@ def kind(self) -> str: return "grpc" -__all__ = ( - 'DiscussServiceGrpcTransport', -) +__all__ = ("DiscussServiceGrpcTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc_asyncio.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc_asyncio.py similarity index 81% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc_asyncio.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc_asyncio.py index 48e36ab4d7ad..f465e0191147 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc_asyncio.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/grpc_asyncio.py @@ -13,20 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union +import warnings -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers_async -from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import gapic_v1, grpc_helpers_async +from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore +from google.longrunning import operations_pb2 # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.ai.generativelanguage_v1beta3.types import discuss_service -from google.longrunning import operations_pb2 # type: ignore -from .base import DiscussServiceTransport, DEFAULT_CLIENT_INFO + +from .base import DEFAULT_CLIENT_INFO, DiscussServiceTransport from .grpc import DiscussServiceGrpcTransport @@ -50,13 +49,15 @@ class DiscussServiceGrpcAsyncIOTransport(DiscussServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -87,24 +88,26 @@ def create_channel(cls, default_scopes=cls.AUTH_SCOPES, scopes=scopes, default_host=cls.DEFAULT_HOST, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[aio.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: """Instantiate the transport. 
Args: @@ -233,9 +236,12 @@ def grpc_channel(self) -> aio.Channel: return self._grpc_channel @property - def generate_message(self) -> Callable[ - [discuss_service.GenerateMessageRequest], - Awaitable[discuss_service.GenerateMessageResponse]]: + def generate_message( + self, + ) -> Callable[ + [discuss_service.GenerateMessageRequest], + Awaitable[discuss_service.GenerateMessageResponse], + ]: r"""Return a callable for the generate message method over gRPC. Generates a response from the model given an input @@ -251,18 +257,21 @@ def generate_message(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'generate_message' not in self._stubs: - self._stubs['generate_message'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.DiscussService/GenerateMessage', + if "generate_message" not in self._stubs: + self._stubs["generate_message"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.DiscussService/GenerateMessage", request_serializer=discuss_service.GenerateMessageRequest.serialize, response_deserializer=discuss_service.GenerateMessageResponse.deserialize, ) - return self._stubs['generate_message'] + return self._stubs["generate_message"] @property - def count_message_tokens(self) -> Callable[ - [discuss_service.CountMessageTokensRequest], - Awaitable[discuss_service.CountMessageTokensResponse]]: + def count_message_tokens( + self, + ) -> Callable[ + [discuss_service.CountMessageTokensRequest], + Awaitable[discuss_service.CountMessageTokensResponse], + ]: r"""Return a callable for the count message tokens method over gRPC. Runs a model's tokenizer on a string and returns the @@ -278,18 +287,16 @@ def count_message_tokens(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
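    # Editorial sketch (not part of the patch): the stub-caching idiom these
    # transports repeat, reduced to its essentials. Names are illustrative only.
    #
    #     def _cached_stub(self, name, path, request_type, response_type):
    #         # Create the unary-unary stub on first use and cache it, so every
    #         # later call reuses the same wrapped channel method.
    #         if name not in self._stubs:
    #             self._stubs[name] = self.grpc_channel.unary_unary(
    #                 path,
    #                 request_serializer=request_type.serialize,
    #                 response_deserializer=response_type.deserialize,
    #             )
    #         return self._stubs[name]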
- if 'count_message_tokens' not in self._stubs: - self._stubs['count_message_tokens'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.DiscussService/CountMessageTokens', + if "count_message_tokens" not in self._stubs: + self._stubs["count_message_tokens"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.DiscussService/CountMessageTokens", request_serializer=discuss_service.CountMessageTokensRequest.serialize, response_deserializer=discuss_service.CountMessageTokensResponse.deserialize, ) - return self._stubs['count_message_tokens'] + return self._stubs["count_message_tokens"] def close(self): return self.grpc_channel.close() -__all__ = ( - 'DiscussServiceGrpcAsyncIOTransport', -) +__all__ = ("DiscussServiceGrpcAsyncIOTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/rest.py similarity index 73% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/rest.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/rest.py index 0585ca398116..1df043b1a604 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/rest.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/discuss_service/transports/rest.py @@ -14,24 +14,21 @@ # limitations under the License. # -from google.auth.transport.requests import AuthorizedSession # type: ignore +import dataclasses import json # type: ignore -import grpc # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth import credentials as ga_credentials # type: ignore +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import gapic_v1, path_template, rest_helpers, rest_streaming from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries -from google.api_core import rest_helpers -from google.api_core import rest_streaming -from google.api_core import path_template -from google.api_core import gapic_v1 - +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.transport.requests import AuthorizedSession # type: ignore from google.protobuf import json_format +import grpc # type: ignore from requests import __version__ as requests_version -import dataclasses -import re -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union -import warnings try: OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] @@ -39,11 +36,12 @@ OptionalRetry = Union[retries.Retry, object] # type: ignore -from google.ai.generativelanguage_v1beta3.types import discuss_service from google.longrunning import operations_pb2 # type: ignore -from .base import DiscussServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from google.ai.generativelanguage_v1beta3.types import discuss_service +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .base import DiscussServiceTransport DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( 
gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, @@ -88,7 +86,12 @@ def post_generate_message(self, response): """ - def pre_count_message_tokens(self, request: discuss_service.CountMessageTokensRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[discuss_service.CountMessageTokensRequest, Sequence[Tuple[str, str]]]: + + def pre_count_message_tokens( + self, + request: discuss_service.CountMessageTokensRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[discuss_service.CountMessageTokensRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for count_message_tokens Override in a subclass to manipulate the request or metadata @@ -96,7 +99,9 @@ def pre_count_message_tokens(self, request: discuss_service.CountMessageTokensRe """ return request, metadata - def post_count_message_tokens(self, response: discuss_service.CountMessageTokensResponse) -> discuss_service.CountMessageTokensResponse: + def post_count_message_tokens( + self, response: discuss_service.CountMessageTokensResponse + ) -> discuss_service.CountMessageTokensResponse: """Post-rpc interceptor for count_message_tokens Override in a subclass to manipulate the response @@ -104,7 +109,12 @@ def post_count_message_tokens(self, response: discuss_service.CountMessageTokens it is returned to user code. """ return response - def pre_generate_message(self, request: discuss_service.GenerateMessageRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[discuss_service.GenerateMessageRequest, Sequence[Tuple[str, str]]]: + + def pre_generate_message( + self, + request: discuss_service.GenerateMessageRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[discuss_service.GenerateMessageRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for generate_message Override in a subclass to manipulate the request or metadata @@ -112,7 +122,9 @@ def pre_generate_message(self, request: discuss_service.GenerateMessageRequest, """ return request, metadata - def post_generate_message(self, response: discuss_service.GenerateMessageResponse) -> discuss_service.GenerateMessageResponse: + def post_generate_message( + self, response: discuss_service.GenerateMessageResponse + ) -> discuss_service.GenerateMessageResponse: """Post-rpc interceptor for generate_message Override in a subclass to manipulate the response @@ -145,20 +157,21 @@ class DiscussServiceRestTransport(DiscussServiceTransport): """ - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - client_cert_source_for_mtls: Optional[Callable[[ - ], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - url_scheme: str = 'https', - interceptor: Optional[DiscussServiceRestInterceptor] = None, - api_audience: Optional[str] = None, - ) -> None: + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: 
Optional[DiscussServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: """Instantiate the transport. Args: @@ -197,7 +210,9 @@ def __init__(self, *, # credentials object maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host) if maybe_url_match is None: - raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER url_match_items = maybe_url_match.groupdict() @@ -208,10 +223,11 @@ def __init__(self, *, credentials=credentials, client_info=client_info, always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience + api_audience=api_audience, ) self._session = AuthorizedSession( - self._credentials, default_host=self.DEFAULT_HOST) + self._credentials, default_host=self.DEFAULT_HOST + ) if client_cert_source_for_mtls: self._session.configure_mtls_channel(client_cert_source_for_mtls) self._interceptor = interceptor or DiscussServiceRestInterceptor() @@ -221,19 +237,24 @@ class _CountMessageTokens(DiscussServiceRestStub): def __hash__(self): return hash("CountMessageTokens") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: discuss_service.CountMessageTokensRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> discuss_service.CountMessageTokensResponse: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: discuss_service.CountMessageTokensRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> discuss_service.CountMessageTokensResponse: r"""Call the count message tokens method over HTTP. 
Args: @@ -258,46 +279,51 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'post', - 'uri': '/v1beta3/{model=models/*}:countMessageTokens', - 'body': '*', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1beta3/{model=models/*}:countMessageTokens", + "body": "*", + }, ] - request, metadata = self._interceptor.pre_count_message_tokens(request, metadata) + request, metadata = self._interceptor.pre_count_message_tokens( + request, metadata + ) pb_request = discuss_service.CountMessageTokensRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) # Jsonify the request body body = json_format.MessageToJson( - transcoded_request['body'], + transcoded_request["body"], including_default_value_fields=False, - use_integers_for_enums=True + use_integers_for_enums=True, ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), data=body, - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -316,19 +342,24 @@ class _GenerateMessage(DiscussServiceRestStub): def __hash__(self): return hash("GenerateMessage") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: discuss_service.GenerateMessageRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> discuss_service.GenerateMessageResponse: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: discuss_service.GenerateMessageRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> discuss_service.GenerateMessageResponse: r"""Call the generate message method over HTTP. 
Args: @@ -351,46 +382,51 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'post', - 'uri': '/v1beta3/{model=models/*}:generateMessage', - 'body': '*', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1beta3/{model=models/*}:generateMessage", + "body": "*", + }, ] - request, metadata = self._interceptor.pre_generate_message(request, metadata) + request, metadata = self._interceptor.pre_generate_message( + request, metadata + ) pb_request = discuss_service.GenerateMessageRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) # Jsonify the request body body = json_format.MessageToJson( - transcoded_request['body'], + transcoded_request["body"], including_default_value_fields=False, - use_integers_for_enums=True + use_integers_for_enums=True, ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), data=body, - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -406,20 +442,26 @@ def __call__(self, return resp @property - def count_message_tokens(self) -> Callable[ - [discuss_service.CountMessageTokensRequest], - discuss_service.CountMessageTokensResponse]: + def count_message_tokens( + self, + ) -> Callable[ + [discuss_service.CountMessageTokensRequest], + discuss_service.CountMessageTokensResponse, + ]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._CountMessageTokens(self._session, self._host, self._interceptor) # type: ignore + return self._CountMessageTokens(self._session, self._host, self._interceptor) # type: ignore @property - def generate_message(self) -> Callable[ - [discuss_service.GenerateMessageRequest], - discuss_service.GenerateMessageResponse]: + def generate_message( + self, + ) -> Callable[ + [discuss_service.GenerateMessageRequest], + discuss_service.GenerateMessageResponse, + ]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
# In C++ this would require a dynamic_cast - return self._GenerateMessage(self._session, self._host, self._interceptor) # type: ignore + return self._GenerateMessage(self._session, self._host, self._interceptor) # type: ignore @property def kind(self) -> str: @@ -429,6 +471,4 @@ def close(self): self._session.close() -__all__=( - 'DiscussServiceRestTransport', -) +__all__ = ("DiscussServiceRestTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/__init__.py similarity index 92% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/__init__.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/__init__.py index 2c368b92d844..5738b8bf4239 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/__init__.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/__init__.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from .client import ModelServiceClient from .async_client import ModelServiceAsyncClient +from .client import ModelServiceClient __all__ = ( - 'ModelServiceClient', - 'ModelServiceAsyncClient', + "ModelServiceClient", + "ModelServiceAsyncClient", ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/async_client.py similarity index 85% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/async_client.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/async_client.py index 0759ce5d1845..3257f2817e7e 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/async_client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/async_client.py @@ -16,35 +16,46 @@ from collections import OrderedDict import functools import re -from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union - -from google.ai.generativelanguage_v1beta3 import gapic_version as package_version +from typing import ( + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) -from google.api_core.client_options import ClientOptions from google.api_core import exceptions as core_exceptions from google.api_core import gapic_v1 from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core.client_options import ClientOptions +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version try: OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] except AttributeError: # pragma: NO 
COVER OptionalRetry = Union[retries.Retry, object] # type: ignore -from google.ai.generativelanguage_v1beta3.services.model_service import pagers -from google.ai.generativelanguage_v1beta3.types import model -from google.ai.generativelanguage_v1beta3.types import model_service -from google.ai.generativelanguage_v1beta3.types import tuned_model -from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore -from google.longrunning import operations_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore from google.protobuf import field_mask_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore -from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport + +from google.ai.generativelanguage_v1beta3.services.model_service import pagers +from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model +from google.ai.generativelanguage_v1beta3.types import model, model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model + from .client import ModelServiceClient +from .transports.base import DEFAULT_CLIENT_INFO, ModelServiceTransport +from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport class ModelServiceAsyncClient: @@ -61,16 +72,26 @@ class ModelServiceAsyncClient: parse_model_path = staticmethod(ModelServiceClient.parse_model_path) tuned_model_path = staticmethod(ModelServiceClient.tuned_model_path) parse_tuned_model_path = staticmethod(ModelServiceClient.parse_tuned_model_path) - common_billing_account_path = staticmethod(ModelServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(ModelServiceClient.parse_common_billing_account_path) + common_billing_account_path = staticmethod( + ModelServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + ModelServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(ModelServiceClient.common_folder_path) parse_common_folder_path = staticmethod(ModelServiceClient.parse_common_folder_path) common_organization_path = staticmethod(ModelServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(ModelServiceClient.parse_common_organization_path) + parse_common_organization_path = staticmethod( + ModelServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(ModelServiceClient.common_project_path) - parse_common_project_path = staticmethod(ModelServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + ModelServiceClient.parse_common_project_path + ) common_location_path = staticmethod(ModelServiceClient.common_location_path) - parse_common_location_path = staticmethod(ModelServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + ModelServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -106,7 +127,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file @classmethod - def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None): + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): """Return the API 
endpoint and client cert source for mutual TLS. The client cert source is determined in the following order: @@ -148,14 +171,18 @@ def transport(self) -> ModelServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(ModelServiceClient).get_transport_class, type(ModelServiceClient)) - - def __init__(self, *, - credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, ModelServiceTransport] = "grpc_asyncio", - client_options: Optional[ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + get_transport_class = functools.partial( + type(ModelServiceClient).get_transport_class, type(ModelServiceClient) + ) + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Union[str, ModelServiceTransport] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiates the model service client. Args: @@ -193,17 +220,17 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def get_model(self, - request: Optional[Union[model_service.GetModelRequest, dict]] = None, - *, - name: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + async def get_model( + self, + request: Optional[Union[model_service.GetModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets information about a specific Model. .. code-block:: python @@ -264,8 +291,10 @@ async def sample_get_model(): # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.GetModelRequest(request) @@ -285,9 +314,7 @@ async def sample_get_model(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -301,15 +328,16 @@ async def sample_get_model(): # Done; return the response. 
return response - async def list_models(self, - request: Optional[Union[model_service.ListModelsRequest, dict]] = None, - *, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsAsyncPager: + async def list_models( + self, + request: Optional[Union[model_service.ListModelsRequest, dict]] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsAsyncPager: r"""Lists models available through the API. .. code-block:: python @@ -386,8 +414,10 @@ async def sample_list_models(): # gotten any keyword arguments that map to the request. has_flattened_params = any([page_size, page_token]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ListModelsRequest(request) @@ -426,14 +456,15 @@ async def sample_list_models(): # Done; return the response. return response - async def get_tuned_model(self, - request: Optional[Union[model_service.GetTunedModelRequest, dict]] = None, - *, - name: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> tuned_model.TunedModel: + async def get_tuned_model( + self, + request: Optional[Union[model_service.GetTunedModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tuned_model.TunedModel: r"""Gets information about a specific TunedModel. .. code-block:: python @@ -491,8 +522,10 @@ async def sample_get_tuned_model(): # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.GetTunedModelRequest(request) @@ -512,9 +545,7 @@ async def sample_get_tuned_model(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -528,15 +559,16 @@ async def sample_get_tuned_model(): # Done; return the response. 
return response - async def list_tuned_models(self, - request: Optional[Union[model_service.ListTunedModelsRequest, dict]] = None, - *, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTunedModelsAsyncPager: + async def list_tuned_models( + self, + request: Optional[Union[model_service.ListTunedModelsRequest, dict]] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTunedModelsAsyncPager: r"""Lists tuned models owned by the user. .. code-block:: python @@ -614,8 +646,10 @@ async def sample_list_tuned_models(): # gotten any keyword arguments that map to the request. has_flattened_params = any([page_size, page_token]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.ListTunedModelsRequest(request) @@ -654,15 +688,16 @@ async def sample_list_tuned_models(): # Done; return the response. return response - async def create_tuned_model(self, - request: Optional[Union[model_service.CreateTunedModelRequest, dict]] = None, - *, - tuned_model: Optional[gag_tuned_model.TunedModel] = None, - tuned_model_id: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation_async.AsyncOperation: + async def create_tuned_model( + self, + request: Optional[Union[model_service.CreateTunedModelRequest, dict]] = None, + *, + tuned_model: Optional[gag_tuned_model.TunedModel] = None, + tuned_model_id: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: r"""Creates a tuned model. Intermediate tuning progress (if any) is accessed through the [google.longrunning.Operations] service. @@ -743,8 +778,10 @@ async def sample_create_tuned_model(): # gotten any keyword arguments that map to the request. has_flattened_params = any([tuned_model, tuned_model_id]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.CreateTunedModelRequest(request) @@ -782,15 +819,16 @@ async def sample_create_tuned_model(): # Done; return the response. 
return response - async def update_tuned_model(self, - request: Optional[Union[model_service.UpdateTunedModelRequest, dict]] = None, - *, - tuned_model: Optional[gag_tuned_model.TunedModel] = None, - update_mask: Optional[field_mask_pb2.FieldMask] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gag_tuned_model.TunedModel: + async def update_tuned_model( + self, + request: Optional[Union[model_service.UpdateTunedModelRequest, dict]] = None, + *, + tuned_model: Optional[gag_tuned_model.TunedModel] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_tuned_model.TunedModel: r"""Updates a tuned model. .. code-block:: python @@ -855,8 +893,10 @@ async def sample_update_tuned_model(): # gotten any keyword arguments that map to the request. has_flattened_params = any([tuned_model, update_mask]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.UpdateTunedModelRequest(request) @@ -878,9 +918,9 @@ async def sample_update_tuned_model(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("tuned_model.name", request.tuned_model.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("tuned_model.name", request.tuned_model.name),) + ), ) # Send the request. @@ -894,14 +934,15 @@ async def sample_update_tuned_model(): # Done; return the response. return response - async def delete_tuned_model(self, - request: Optional[Union[model_service.DeleteTunedModelRequest, dict]] = None, - *, - name: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def delete_tuned_model( + self, + request: Optional[Union[model_service.DeleteTunedModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Deletes a tuned model. .. code-block:: python @@ -948,8 +989,10 @@ async def sample_delete_tuned_model(): # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = model_service.DeleteTunedModelRequest(request) @@ -969,9 +1012,7 @@ async def sample_delete_tuned_model(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. 
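The hunks above all reshape the same boilerplate: before each RPC the client appends routing metadata built with gapic_v1.routing_header.to_grpc_metadata, exactly as the reformatted blocks show. A minimal sketch of what that helper call produces (the resource name below is purely illustrative):

.. code-block:: python

    from google.api_core import gapic_v1

    # Same pattern as the generated clients: append one routing-header
    # entry to whatever metadata the caller supplied.
    metadata = ()
    metadata = tuple(metadata) + (
        gapic_v1.routing_header.to_grpc_metadata((("name", "tunedModels/example"),)),
    )
    # metadata now ends with an ("x-goog-request-params", "name=...") pair;
    # the value is URL-encoded, and the backend uses it to route the request.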
@@ -988,9 +1029,10 @@ async def __aenter__(self) -> "ModelServiceAsyncClient": async def __aexit__(self, exc_type, exc, tb): await self.transport.close() -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - -__all__ = ( - "ModelServiceAsyncClient", +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ ) + + +__all__ = ("ModelServiceAsyncClient",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/client.py similarity index 83% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/client.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/client.py index d64fa37b6f6f..357ecd294e45 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/client.py @@ -16,36 +16,48 @@ from collections import OrderedDict import os import re -from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast - -from google.ai.generativelanguage_v1beta3 import gapic_version as package_version +from typing import ( + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) from google.api_core import client_options as client_options_lib from google.api_core import exceptions as core_exceptions from google.api_core import gapic_v1 from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version try: OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] except AttributeError: # pragma: NO COVER OptionalRetry = Union[retries.Retry, object] # type: ignore -from google.ai.generativelanguage_v1beta3.services.model_service import pagers -from google.ai.generativelanguage_v1beta3.types import model -from google.ai.generativelanguage_v1beta3.types import model_service -from google.ai.generativelanguage_v1beta3.types import tuned_model -from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model from google.api_core import operation # type: ignore from google.api_core import operation_async # type: ignore -from google.longrunning import operations_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore from google.protobuf import field_mask_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore -from .transports.base import ModelServiceTransport, 
DEFAULT_CLIENT_INFO + +from google.ai.generativelanguage_v1beta3.services.model_service import pagers +from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model +from google.ai.generativelanguage_v1beta3.types import model, model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model + +from .transports.base import DEFAULT_CLIENT_INFO, ModelServiceTransport from .transports.grpc import ModelServiceGrpcTransport from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport from .transports.rest import ModelServiceRestTransport @@ -58,14 +70,16 @@ class ModelServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ + _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] _transport_registry["grpc"] = ModelServiceGrpcTransport _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport _transport_registry["rest"] = ModelServiceRestTransport - def get_transport_class(cls, - label: Optional[str] = None, - ) -> Type[ModelServiceTransport]: + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[ModelServiceTransport]: """Returns an appropriate transport class. Args: @@ -155,8 +169,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: ModelServiceClient: The constructed client. """ - credentials = service_account.Credentials.from_service_account_file( - filename) + credentials = service_account.Credentials.from_service_account_file(filename) kwargs["credentials"] = credentials return cls(*args, **kwargs) @@ -173,84 +186,116 @@ def transport(self) -> ModelServiceTransport: return self._transport @staticmethod - def model_path(model: str,) -> str: + def model_path( + model: str, + ) -> str: """Returns a fully-qualified model string.""" - return "models/{model}".format(model=model, ) + return "models/{model}".format( + model=model, + ) @staticmethod - def parse_model_path(path: str) -> Dict[str,str]: + def parse_model_path(path: str) -> Dict[str, str]: """Parses a model path into its component segments.""" m = re.match(r"^models/(?P<model>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def tuned_model_path(tuned_model: str,) -> str: + def tuned_model_path( + tuned_model: str, + ) -> str: """Returns a fully-qualified tuned_model string.""" - return "tunedModels/{tuned_model}".format(tuned_model=tuned_model, ) + return "tunedModels/{tuned_model}".format( + tuned_model=tuned_model, + ) @staticmethod - def parse_tuned_model_path(path: str) -> Dict[str,str]: + def parse_tuned_model_path(path: str) -> Dict[str, str]: """Parses a tuned_model path into its component segments.""" m = re.match(r"^tunedModels/(?P<tuned_model>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path( + billing_account: str, + ) -> str: """Returns a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path( + folder: str, + ) -> str: """Returns a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format( + folder=folder, + ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P<folder>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path( + organization: str, + ) -> str: """Returns a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format( + organization=organization, + ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P<organization>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path( + project: str, + ) -> str: """Returns a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format( + project=project, + ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P<project>.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path( + project: str, + location: str, + ) -> str: """Returns a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path) return m.groupdict() if m else {} @classmethod - def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None): + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): """Return the API endpoint and client cert source for mutual TLS.
The client cert source is determined in the following order: @@ -286,9 +331,13 @@ def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_optio use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") if use_client_cert not in ("true", "false"): - raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`") + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) if use_mtls_endpoint not in ("auto", "never", "always"): - raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`") + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) # Figure out the client cert source to use. client_cert_source = None @@ -301,19 +350,23 @@ def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_optio # Figure out which api endpoint to use. if client_options.api_endpoint is not None: api_endpoint = client_options.api_endpoint - elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source): + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): api_endpoint = cls.DEFAULT_MTLS_ENDPOINT else: api_endpoint = cls.DEFAULT_ENDPOINT return api_endpoint, client_cert_source - def __init__(self, *, - credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, ModelServiceTransport]] = None, - client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[Union[str, ModelServiceTransport]] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiates the model service client. Args: @@ -357,11 +410,15 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() client_options = cast(client_options_lib.ClientOptions, client_options) - api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options + ) api_key_value = getattr(client_options, "api_key", None) if api_key_value and credentials: - raise ValueError("client_options.api_key and credentials are mutually exclusive") + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport @@ -369,8 +426,10 @@ def __init__(self, *, if isinstance(transport, ModelServiceTransport): # transport is a ModelServiceTransport instance. if credentials or client_options.credentials_file or api_key_value: - raise ValueError("When providing a transport instance, " - "provide its credentials directly.") + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." 
+ ) if client_options.scopes: raise ValueError( "When providing a transport instance, provide its scopes " @@ -380,8 +439,12 @@ def __init__(self, *, else: import google.auth._default # type: ignore - if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): - credentials = google.auth._default.get_api_key_credentials(api_key_value) + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) Transport = type(self).get_transport_class(transport) self._transport = Transport( @@ -396,14 +459,15 @@ def __init__(self, *, api_audience=client_options.api_audience, ) - def get_model(self, - request: Optional[Union[model_service.GetModelRequest, dict]] = None, - *, - name: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> model.Model: + def get_model( + self, + request: Optional[Union[model_service.GetModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Gets information about a specific Model. .. code-block:: python @@ -464,8 +528,10 @@ def sample_get_model(): # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetModelRequest. @@ -482,12 +548,10 @@ def sample_get_model(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.get_model] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -501,15 +565,16 @@ def sample_get_model(): # Done; return the response. return response - def list_models(self, - request: Optional[Union[model_service.ListModelsRequest, dict]] = None, - *, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListModelsPager: + def list_models( + self, + request: Optional[Union[model_service.ListModelsRequest, dict]] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsPager: r"""Lists models available through the API. .. code-block:: python @@ -586,8 +651,10 @@ def sample_list_models(): # gotten any keyword arguments that map to the request. 
has_flattened_params = any([page_size, page_token]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListModelsRequest. @@ -626,14 +693,15 @@ def sample_list_models(): # Done; return the response. return response - def get_tuned_model(self, - request: Optional[Union[model_service.GetTunedModelRequest, dict]] = None, - *, - name: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> tuned_model.TunedModel: + def get_tuned_model( + self, + request: Optional[Union[model_service.GetTunedModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tuned_model.TunedModel: r"""Gets information about a specific TunedModel. .. code-block:: python @@ -691,8 +759,10 @@ def sample_get_tuned_model(): # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.GetTunedModelRequest. @@ -709,12 +779,10 @@ def sample_get_tuned_model(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.get_tuned_model] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -728,15 +796,16 @@ def sample_get_tuned_model(): # Done; return the response. return response - def list_tuned_models(self, - request: Optional[Union[model_service.ListTunedModelsRequest, dict]] = None, - *, - page_size: Optional[int] = None, - page_token: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListTunedModelsPager: + def list_tuned_models( + self, + request: Optional[Union[model_service.ListTunedModelsRequest, dict]] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTunedModelsPager: r"""Lists tuned models owned by the user. .. code-block:: python @@ -814,8 +883,10 @@ def sample_list_tuned_models(): # gotten any keyword arguments that map to the request. 
has_flattened_params = any([page_size, page_token]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.ListTunedModelsRequest. @@ -854,15 +925,16 @@ def sample_list_tuned_models(): # Done; return the response. return response - def create_tuned_model(self, - request: Optional[Union[model_service.CreateTunedModelRequest, dict]] = None, - *, - tuned_model: Optional[gag_tuned_model.TunedModel] = None, - tuned_model_id: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> operation.Operation: + def create_tuned_model( + self, + request: Optional[Union[model_service.CreateTunedModelRequest, dict]] = None, + *, + tuned_model: Optional[gag_tuned_model.TunedModel] = None, + tuned_model_id: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: r"""Creates a tuned model. Intermediate tuning progress (if any) is accessed through the [google.longrunning.Operations] service. @@ -943,8 +1015,10 @@ def sample_create_tuned_model(): # gotten any keyword arguments that map to the request. has_flattened_params = any([tuned_model, tuned_model_id]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.CreateTunedModelRequest. @@ -982,15 +1056,16 @@ def sample_create_tuned_model(): # Done; return the response. return response - def update_tuned_model(self, - request: Optional[Union[model_service.UpdateTunedModelRequest, dict]] = None, - *, - tuned_model: Optional[gag_tuned_model.TunedModel] = None, - update_mask: Optional[field_mask_pb2.FieldMask] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gag_tuned_model.TunedModel: + def update_tuned_model( + self, + request: Optional[Union[model_service.UpdateTunedModelRequest, dict]] = None, + *, + tuned_model: Optional[gag_tuned_model.TunedModel] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_tuned_model.TunedModel: r"""Updates a tuned model. .. code-block:: python @@ -1055,8 +1130,10 @@ def sample_update_tuned_model(): # gotten any keyword arguments that map to the request. has_flattened_params = any([tuned_model, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) # Minor optimization to avoid making a copy if the user passes # in a model_service.UpdateTunedModelRequest. @@ -1075,12 +1152,12 @@ def sample_update_tuned_model(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.update_tuned_model] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("tuned_model.name", request.tuned_model.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("tuned_model.name", request.tuned_model.name),) + ), ) # Send the request. @@ -1094,14 +1171,15 @@ def sample_update_tuned_model(): # Done; return the response. return response - def delete_tuned_model(self, - request: Optional[Union[model_service.DeleteTunedModelRequest, dict]] = None, - *, - name: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def delete_tuned_model( + self, + request: Optional[Union[model_service.DeleteTunedModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Deletes a tuned model. .. code-block:: python @@ -1148,8 +1226,10 @@ def sample_delete_tuned_model(): # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a model_service.DeleteTunedModelRequest. @@ -1166,12 +1246,10 @@ def sample_delete_tuned_model(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.delete_tuned_model] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. 
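These client.py hunks also rewrap the recurring guard that keeps the request object and the flattened field arguments mutually exclusive. A short usage sketch of the rule being enforced, assuming default credentials are available and using a hypothetical tuned-model name:

.. code-block:: python

    from google.ai.generativelanguage_v1beta3 import ModelServiceClient
    from google.ai.generativelanguage_v1beta3.types import model_service

    client = ModelServiceClient()

    # Either pass a fully-formed request object...
    client.delete_tuned_model(
        request=model_service.DeleteTunedModelRequest(name="tunedModels/example")
    )

    # ...or the flattened field, but never both at once:
    client.delete_tuned_model(name="tunedModels/example")

    # Supplying both triggers the ValueError raised in the guard above:
    # client.delete_tuned_model(request=..., name="tunedModels/example")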
@@ -1196,18 +1274,9 @@ def __exit__(self, type, value, traceback): self.transport.close() +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) - - - - - - - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - - -__all__ = ( - "ModelServiceClient", -) +__all__ = ("ModelServiceClient",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/pagers.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/pagers.py similarity index 84% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/pagers.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/pagers.py index ede1634e0b87..a0c9b4f1c151 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/pagers.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/pagers.py @@ -13,11 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Any, AsyncIterator, Awaitable, Callable, Sequence, Tuple, Optional, Iterator +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Iterator, + Optional, + Sequence, + Tuple, +) -from google.ai.generativelanguage_v1beta3.types import model -from google.ai.generativelanguage_v1beta3.types import model_service -from google.ai.generativelanguage_v1beta3.types import tuned_model +from google.ai.generativelanguage_v1beta3.types import model, model_service, tuned_model class ListModelsPager: @@ -37,12 +44,15 @@ class ListModelsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., model_service.ListModelsResponse], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., model_service.ListModelsResponse], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -76,7 +86,7 @@ def __iter__(self) -> Iterator[model.Model]: yield from page.models def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListModelsAsyncPager: @@ -96,12 +106,15 @@ class ListModelsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[model_service.ListModelsResponse]], - request: model_service.ListModelsRequest, - response: model_service.ListModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListModelsResponse]], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiates the pager. 
Args: @@ -129,6 +142,7 @@ async def pages(self) -> AsyncIterator[model_service.ListModelsResponse]: self._request.page_token = self._response.next_page_token self._response = await self._method(self._request, metadata=self._metadata) yield self._response + def __aiter__(self) -> AsyncIterator[model.Model]: async def async_generator(): async for page in self.pages: @@ -138,7 +152,7 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListTunedModelsPager: @@ -158,12 +172,15 @@ class ListTunedModelsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., model_service.ListTunedModelsResponse], - request: model_service.ListTunedModelsRequest, - response: model_service.ListTunedModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., model_service.ListTunedModelsResponse], + request: model_service.ListTunedModelsRequest, + response: model_service.ListTunedModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -197,7 +214,7 @@ def __iter__(self) -> Iterator[tuned_model.TunedModel]: yield from page.tuned_models def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListTunedModelsAsyncPager: @@ -217,12 +234,15 @@ class ListTunedModelsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[model_service.ListTunedModelsResponse]], - request: model_service.ListTunedModelsRequest, - response: model_service.ListTunedModelsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListTunedModelsResponse]], + request: model_service.ListTunedModelsRequest, + response: model_service.ListTunedModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiates the pager. 
Args: @@ -250,6 +270,7 @@ async def pages(self) -> AsyncIterator[model_service.ListTunedModelsResponse]: self._request.page_token = self._response.next_page_token self._response = await self._method(self._request, metadata=self._metadata) yield self._response + def __aiter__(self) -> AsyncIterator[tuned_model.TunedModel]: async def async_generator(): async for page in self.pages: @@ -259,4 +280,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/__init__.py similarity index 68% rename from owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/__init__.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/__init__.py index c51cadf4ba09..1b430a25489e 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/model_service/transports/__init__.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/__init__.py @@ -19,20 +19,18 @@ from .base import ModelServiceTransport from .grpc import ModelServiceGrpcTransport from .grpc_asyncio import ModelServiceGrpcAsyncIOTransport -from .rest import ModelServiceRestTransport -from .rest import ModelServiceRestInterceptor - +from .rest import ModelServiceRestInterceptor, ModelServiceRestTransport # Compile a registry of transports. 
_transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] -_transport_registry['grpc'] = ModelServiceGrpcTransport -_transport_registry['grpc_asyncio'] = ModelServiceGrpcAsyncIOTransport -_transport_registry['rest'] = ModelServiceRestTransport +_transport_registry["grpc"] = ModelServiceGrpcTransport +_transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport +_transport_registry["rest"] = ModelServiceRestTransport __all__ = ( - 'ModelServiceTransport', - 'ModelServiceGrpcTransport', - 'ModelServiceGrpcAsyncIOTransport', - 'ModelServiceRestTransport', - 'ModelServiceRestInterceptor', + "ModelServiceTransport", + "ModelServiceGrpcTransport", + "ModelServiceGrpcAsyncIOTransport", + "ModelServiceRestTransport", + "ModelServiceRestInterceptor", ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/base.py similarity index 62% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/base.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/base.py index dcc5074c3ae7..218683129258 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/base.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/base.py @@ -16,46 +16,46 @@ import abc from typing import Awaitable, Callable, Dict, Optional, Sequence, Union -from google.ai.generativelanguage_v1beta3 import gapic_version as package_version - -import google.auth # type: ignore import google.api_core from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 +from google.api_core import gapic_v1, operations_v1 from google.api_core import retry as retries -from google.api_core import operations_v1 +import google.auth # type: ignore from google.auth import credentials as ga_credentials # type: ignore -from google.oauth2 import service_account # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account # type: ignore +from google.protobuf import empty_pb2 # type: ignore -from google.ai.generativelanguage_v1beta3.types import model -from google.ai.generativelanguage_v1beta3.types import model_service -from google.ai.generativelanguage_v1beta3.types import tuned_model +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model -from google.longrunning import operations_pb2 # type: ignore -from google.protobuf import empty_pb2 # type: ignore +from google.ai.generativelanguage_v1beta3.types import model, model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) class ModelServiceTransport(abc.ABC): """Abstract transport class for ModelService.""" - AUTH_SCOPES = ( - ) + AUTH_SCOPES = () + + DEFAULT_HOST: str = "generativelanguage.googleapis.com" - DEFAULT_HOST: str = 'generativelanguage.googleapis.com' def 
__init__( - self, *, - host: str = DEFAULT_HOST, - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - **kwargs, - ) -> None: + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -89,30 +89,38 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = google.auth.load_credentials_from_file( - credentials_file, - **scopes_kwargs, - quota_project_id=quota_project_id - ) + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id) + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) # Don't apply audience if the credentials file passed from user. if hasattr(credentials, "with_gdch_audience"): - credentials = credentials.with_gdch_audience(api_audience if api_audience else host) + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) # If the credentials are service account credentials, then always try to use self signed JWT. - if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"): + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): credentials = credentials.with_always_use_jwt_access(True) # Save the credentials. self._credentials = credentials # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host def _prep_wrapped_messages(self, client_info): @@ -153,14 +161,14 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } + } def close(self): """Closes resources associated with the transport. - .. warning:: - Only call this method if the transport is NOT shared - with other clients - this may cause errors in other clients! + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! 
""" raise NotImplementedError() @@ -170,66 +178,71 @@ def operations_client(self): raise NotImplementedError() @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - Union[ - model.Model, - Awaitable[model.Model] - ]]: + def get_model( + self, + ) -> Callable[ + [model_service.GetModelRequest], Union[model.Model, Awaitable[model.Model]] + ]: raise NotImplementedError() @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - Union[ - model_service.ListModelsResponse, - Awaitable[model_service.ListModelsResponse] - ]]: + def list_models( + self, + ) -> Callable[ + [model_service.ListModelsRequest], + Union[ + model_service.ListModelsResponse, + Awaitable[model_service.ListModelsResponse], + ], + ]: raise NotImplementedError() @property - def get_tuned_model(self) -> Callable[ - [model_service.GetTunedModelRequest], - Union[ - tuned_model.TunedModel, - Awaitable[tuned_model.TunedModel] - ]]: + def get_tuned_model( + self, + ) -> Callable[ + [model_service.GetTunedModelRequest], + Union[tuned_model.TunedModel, Awaitable[tuned_model.TunedModel]], + ]: raise NotImplementedError() @property - def list_tuned_models(self) -> Callable[ - [model_service.ListTunedModelsRequest], - Union[ - model_service.ListTunedModelsResponse, - Awaitable[model_service.ListTunedModelsResponse] - ]]: + def list_tuned_models( + self, + ) -> Callable[ + [model_service.ListTunedModelsRequest], + Union[ + model_service.ListTunedModelsResponse, + Awaitable[model_service.ListTunedModelsResponse], + ], + ]: raise NotImplementedError() @property - def create_tuned_model(self) -> Callable[ - [model_service.CreateTunedModelRequest], - Union[ - operations_pb2.Operation, - Awaitable[operations_pb2.Operation] - ]]: + def create_tuned_model( + self, + ) -> Callable[ + [model_service.CreateTunedModelRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: raise NotImplementedError() @property - def update_tuned_model(self) -> Callable[ - [model_service.UpdateTunedModelRequest], - Union[ - gag_tuned_model.TunedModel, - Awaitable[gag_tuned_model.TunedModel] - ]]: + def update_tuned_model( + self, + ) -> Callable[ + [model_service.UpdateTunedModelRequest], + Union[gag_tuned_model.TunedModel, Awaitable[gag_tuned_model.TunedModel]], + ]: raise NotImplementedError() @property - def delete_tuned_model(self) -> Callable[ - [model_service.DeleteTunedModelRequest], - Union[ - empty_pb2.Empty, - Awaitable[empty_pb2.Empty] - ]]: + def delete_tuned_model( + self, + ) -> Callable[ + [model_service.DeleteTunedModelRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: raise NotImplementedError() @property @@ -237,6 +250,4 @@ def kind(self) -> str: raise NotImplementedError() -__all__ = ( - 'ModelServiceTransport', -) +__all__ = ("ModelServiceTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc.py similarity index 77% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc.py index c6a0e54df2da..f8a5be574b10 100644 --- 
a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc.py @@ -13,25 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import warnings from typing import Callable, Dict, Optional, Sequence, Tuple, Union +import warnings -from google.api_core import grpc_helpers -from google.api_core import operations_v1 -from google.api_core import gapic_v1 -import google.auth # type: ignore +from google.api_core import gapic_v1, grpc_helpers, operations_v1 +import google.auth # type: ignore from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore import grpc # type: ignore -from google.ai.generativelanguage_v1beta3.types import model -from google.ai.generativelanguage_v1beta3.types import model_service -from google.ai.generativelanguage_v1beta3.types import tuned_model from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model -from google.longrunning import operations_pb2 # type: ignore -from google.protobuf import empty_pb2 # type: ignore -from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from google.ai.generativelanguage_v1beta3.types import model, model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model + +from .base import DEFAULT_CLIENT_INFO, ModelServiceTransport class ModelServiceGrpcTransport(ModelServiceTransport): @@ -47,23 +44,26 @@ class ModelServiceGrpcTransport(ModelServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[grpc.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: """Instantiate the transport. 
Args: @@ -183,13 +183,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -224,13 +226,12 @@ def create_channel(cls, default_scopes=cls.AUTH_SCOPES, scopes=scopes, default_host=cls.DEFAULT_HOST, - **kwargs + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property @@ -242,17 +243,13 @@ def operations_client(self) -> operations_v1.OperationsClient: """ # Quick check: Only create a new client if we do not already have one. if self._operations_client is None: - self._operations_client = operations_v1.OperationsClient( - self.grpc_channel - ) + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. return self._operations_client @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - model.Model]: + def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]: r"""Return a callable for the get model method over gRPC. Gets information about a specific Model. @@ -267,18 +264,18 @@ def get_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model' not in self._stubs: - self._stubs['get_model'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/GetModel', + if "get_model" not in self._stubs: + self._stubs["get_model"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/GetModel", request_serializer=model_service.GetModelRequest.serialize, response_deserializer=model.Model.deserialize, ) - return self._stubs['get_model'] + return self._stubs["get_model"] @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - model_service.ListModelsResponse]: + def list_models( + self, + ) -> Callable[[model_service.ListModelsRequest], model_service.ListModelsResponse]: r"""Return a callable for the list models method over gRPC. Lists models available through the API. @@ -293,18 +290,18 @@ def list_models(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
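
Every RPC property on this transport follows the same lazy, cached-stub pattern seen in these hunks: the gRPC callable is built on first access and stored in self._stubs, so repeated property reads reuse a single stub per channel. A minimal standalone sketch of the pattern (the class name is illustrative and not part of the generated surface; channel is assumed to be an open grpc.Channel):

from google.ai.generativelanguage_v1beta3.types import model, model_service

class CachedStubTransport:
    """Sketch of the stub-caching pattern used by the generated transports."""

    def __init__(self, channel):
        self._channel = channel  # an open grpc.Channel
        self._stubs = {}

    @property
    def get_model(self):
        # Create the gRPC callable on first access; later reads hit the cache.
        if "get_model" not in self._stubs:
            self._stubs["get_model"] = self._channel.unary_unary(
                "/google.ai.generativelanguage.v1beta3.ModelService/GetModel",
                request_serializer=model_service.GetModelRequest.serialize,
                response_deserializer=model.Model.deserialize,
            )
        return self._stubs["get_model"]
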
- if 'list_models' not in self._stubs: - self._stubs['list_models'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/ListModels', + if "list_models" not in self._stubs: + self._stubs["list_models"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/ListModels", request_serializer=model_service.ListModelsRequest.serialize, response_deserializer=model_service.ListModelsResponse.deserialize, ) - return self._stubs['list_models'] + return self._stubs["list_models"] @property - def get_tuned_model(self) -> Callable[ - [model_service.GetTunedModelRequest], - tuned_model.TunedModel]: + def get_tuned_model( + self, + ) -> Callable[[model_service.GetTunedModelRequest], tuned_model.TunedModel]: r"""Return a callable for the get tuned model method over gRPC. Gets information about a specific TunedModel. @@ -319,18 +316,20 @@ def get_tuned_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_tuned_model' not in self._stubs: - self._stubs['get_tuned_model'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/GetTunedModel', + if "get_tuned_model" not in self._stubs: + self._stubs["get_tuned_model"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/GetTunedModel", request_serializer=model_service.GetTunedModelRequest.serialize, response_deserializer=tuned_model.TunedModel.deserialize, ) - return self._stubs['get_tuned_model'] + return self._stubs["get_tuned_model"] @property - def list_tuned_models(self) -> Callable[ - [model_service.ListTunedModelsRequest], - model_service.ListTunedModelsResponse]: + def list_tuned_models( + self, + ) -> Callable[ + [model_service.ListTunedModelsRequest], model_service.ListTunedModelsResponse + ]: r"""Return a callable for the list tuned models method over gRPC. Lists tuned models owned by the user. @@ -345,18 +344,18 @@ def list_tuned_models(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_tuned_models' not in self._stubs: - self._stubs['list_tuned_models'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/ListTunedModels', + if "list_tuned_models" not in self._stubs: + self._stubs["list_tuned_models"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/ListTunedModels", request_serializer=model_service.ListTunedModelsRequest.serialize, response_deserializer=model_service.ListTunedModelsResponse.deserialize, ) - return self._stubs['list_tuned_models'] + return self._stubs["list_tuned_models"] @property - def create_tuned_model(self) -> Callable[ - [model_service.CreateTunedModelRequest], - operations_pb2.Operation]: + def create_tuned_model( + self, + ) -> Callable[[model_service.CreateTunedModelRequest], operations_pb2.Operation]: r"""Return a callable for the create tuned model method over gRPC. Creates a tuned model. Intermediate tuning progress (if any) is @@ -376,18 +375,18 @@ def create_tuned_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
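
Because CreateTunedModel returns a google.longrunning Operation, the public client wraps the call in an operation future. A hedged usage sketch, assuming Application Default Credentials are configured; the display name and base model are placeholders, and a real tuning job also requires the tuning_task training configuration, which is omitted here:

from google.ai import generativelanguage_v1beta3 as glm

client = glm.ModelServiceClient()
operation = client.create_tuned_model(
    request=glm.CreateTunedModelRequest(
        tuned_model=glm.TunedModel(
            display_name="example-tuned-model",      # placeholder
            base_model="models/text-bison-001",      # placeholder base model
        ),
    ),
)
tuned = operation.result()  # blocks until the tuning operation completes
print(tuned.name)
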
- if 'create_tuned_model' not in self._stubs: - self._stubs['create_tuned_model'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/CreateTunedModel', + if "create_tuned_model" not in self._stubs: + self._stubs["create_tuned_model"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/CreateTunedModel", request_serializer=model_service.CreateTunedModelRequest.serialize, response_deserializer=operations_pb2.Operation.FromString, ) - return self._stubs['create_tuned_model'] + return self._stubs["create_tuned_model"] @property - def update_tuned_model(self) -> Callable[ - [model_service.UpdateTunedModelRequest], - gag_tuned_model.TunedModel]: + def update_tuned_model( + self, + ) -> Callable[[model_service.UpdateTunedModelRequest], gag_tuned_model.TunedModel]: r"""Return a callable for the update tuned model method over gRPC. Updates a tuned model. @@ -402,18 +401,18 @@ def update_tuned_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_tuned_model' not in self._stubs: - self._stubs['update_tuned_model'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/UpdateTunedModel', + if "update_tuned_model" not in self._stubs: + self._stubs["update_tuned_model"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/UpdateTunedModel", request_serializer=model_service.UpdateTunedModelRequest.serialize, response_deserializer=gag_tuned_model.TunedModel.deserialize, ) - return self._stubs['update_tuned_model'] + return self._stubs["update_tuned_model"] @property - def delete_tuned_model(self) -> Callable[ - [model_service.DeleteTunedModelRequest], - empty_pb2.Empty]: + def delete_tuned_model( + self, + ) -> Callable[[model_service.DeleteTunedModelRequest], empty_pb2.Empty]: r"""Return a callable for the delete tuned model method over gRPC. Deletes a tuned model. @@ -428,13 +427,13 @@ def delete_tuned_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'delete_tuned_model' not in self._stubs: - self._stubs['delete_tuned_model'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/DeleteTunedModel', + if "delete_tuned_model" not in self._stubs: + self._stubs["delete_tuned_model"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/DeleteTunedModel", request_serializer=model_service.DeleteTunedModelRequest.serialize, response_deserializer=empty_pb2.Empty.FromString, ) - return self._stubs['delete_tuned_model'] + return self._stubs["delete_tuned_model"] def close(self): self.grpc_channel.close() @@ -444,6 +443,4 @@ def kind(self) -> str: return "grpc" -__all__ = ( - 'ModelServiceGrpcTransport', -) +__all__ = ("ModelServiceGrpcTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc_asyncio.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc_asyncio.py similarity index 77% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc_asyncio.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc_asyncio.py index c8426f6c3910..a9fcf9077b14 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc_asyncio.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/grpc_asyncio.py @@ -13,25 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
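
The asyncio transport below mirrors the synchronous gRPC transport above, swapping grpc.Channel for grpc.aio and returning awaitable callables. For orientation, a minimal sketch of selecting the synchronous gRPC transport through the public client (assumes Application Default Credentials; the model name is a placeholder):

from google.ai import generativelanguage_v1beta3 as glm

client = glm.ModelServiceClient(transport="grpc")  # "rest" selects the REST transport
model = client.get_model(name="models/text-bison-001")  # placeholder name
print(model.display_name)
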
# -import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union +import warnings -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers_async -from google.api_core import operations_v1 -from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import gapic_v1, grpc_helpers_async, operations_v1 +from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore -from google.ai.generativelanguage_v1beta3.types import model -from google.ai.generativelanguage_v1beta3.types import model_service -from google.ai.generativelanguage_v1beta3.types import tuned_model from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model -from google.longrunning import operations_pb2 # type: ignore -from google.protobuf import empty_pb2 # type: ignore -from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from google.ai.generativelanguage_v1beta3.types import model, model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model + +from .base import DEFAULT_CLIENT_INFO, ModelServiceTransport from .grpc import ModelServiceGrpcTransport @@ -53,13 +50,15 @@ class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. 
@@ -90,24 +89,26 @@ def create_channel(cls, default_scopes=cls.AUTH_SCOPES, scopes=scopes, default_host=cls.DEFAULT_HOST, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[aio.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: """Instantiate the transport. Args: @@ -253,9 +254,9 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: return self._operations_client @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - Awaitable[model.Model]]: + def get_model( + self, + ) -> Callable[[model_service.GetModelRequest], Awaitable[model.Model]]: r"""Return a callable for the get model method over gRPC. Gets information about a specific Model. @@ -270,18 +271,20 @@ def get_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_model' not in self._stubs: - self._stubs['get_model'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/GetModel', + if "get_model" not in self._stubs: + self._stubs["get_model"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/GetModel", request_serializer=model_service.GetModelRequest.serialize, response_deserializer=model.Model.deserialize, ) - return self._stubs['get_model'] + return self._stubs["get_model"] @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - Awaitable[model_service.ListModelsResponse]]: + def list_models( + self, + ) -> Callable[ + [model_service.ListModelsRequest], Awaitable[model_service.ListModelsResponse] + ]: r"""Return a callable for the list models method over gRPC. Lists models available through the API. @@ -296,18 +299,20 @@ def list_models(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'list_models' not in self._stubs: - self._stubs['list_models'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/ListModels', + if "list_models" not in self._stubs: + self._stubs["list_models"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/ListModels", request_serializer=model_service.ListModelsRequest.serialize, response_deserializer=model_service.ListModelsResponse.deserialize, ) - return self._stubs['list_models'] + return self._stubs["list_models"] @property - def get_tuned_model(self) -> Callable[ - [model_service.GetTunedModelRequest], - Awaitable[tuned_model.TunedModel]]: + def get_tuned_model( + self, + ) -> Callable[ + [model_service.GetTunedModelRequest], Awaitable[tuned_model.TunedModel] + ]: r"""Return a callable for the get tuned model method over gRPC. Gets information about a specific TunedModel. @@ -322,18 +327,21 @@ def get_tuned_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'get_tuned_model' not in self._stubs: - self._stubs['get_tuned_model'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/GetTunedModel', + if "get_tuned_model" not in self._stubs: + self._stubs["get_tuned_model"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/GetTunedModel", request_serializer=model_service.GetTunedModelRequest.serialize, response_deserializer=tuned_model.TunedModel.deserialize, ) - return self._stubs['get_tuned_model'] + return self._stubs["get_tuned_model"] @property - def list_tuned_models(self) -> Callable[ - [model_service.ListTunedModelsRequest], - Awaitable[model_service.ListTunedModelsResponse]]: + def list_tuned_models( + self, + ) -> Callable[ + [model_service.ListTunedModelsRequest], + Awaitable[model_service.ListTunedModelsResponse], + ]: r"""Return a callable for the list tuned models method over gRPC. Lists tuned models owned by the user. @@ -348,18 +356,20 @@ def list_tuned_models(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_tuned_models' not in self._stubs: - self._stubs['list_tuned_models'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/ListTunedModels', + if "list_tuned_models" not in self._stubs: + self._stubs["list_tuned_models"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/ListTunedModels", request_serializer=model_service.ListTunedModelsRequest.serialize, response_deserializer=model_service.ListTunedModelsResponse.deserialize, ) - return self._stubs['list_tuned_models'] + return self._stubs["list_tuned_models"] @property - def create_tuned_model(self) -> Callable[ - [model_service.CreateTunedModelRequest], - Awaitable[operations_pb2.Operation]]: + def create_tuned_model( + self, + ) -> Callable[ + [model_service.CreateTunedModelRequest], Awaitable[operations_pb2.Operation] + ]: r"""Return a callable for the create tuned model method over gRPC. Creates a tuned model. Intermediate tuning progress (if any) is @@ -379,18 +389,20 @@ def create_tuned_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'create_tuned_model' not in self._stubs: - self._stubs['create_tuned_model'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/CreateTunedModel', + if "create_tuned_model" not in self._stubs: + self._stubs["create_tuned_model"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/CreateTunedModel", request_serializer=model_service.CreateTunedModelRequest.serialize, response_deserializer=operations_pb2.Operation.FromString, ) - return self._stubs['create_tuned_model'] + return self._stubs["create_tuned_model"] @property - def update_tuned_model(self) -> Callable[ - [model_service.UpdateTunedModelRequest], - Awaitable[gag_tuned_model.TunedModel]]: + def update_tuned_model( + self, + ) -> Callable[ + [model_service.UpdateTunedModelRequest], Awaitable[gag_tuned_model.TunedModel] + ]: r"""Return a callable for the update tuned model method over gRPC. Updates a tuned model. @@ -405,18 +417,18 @@ def update_tuned_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_tuned_model' not in self._stubs: - self._stubs['update_tuned_model'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/UpdateTunedModel', + if "update_tuned_model" not in self._stubs: + self._stubs["update_tuned_model"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/UpdateTunedModel", request_serializer=model_service.UpdateTunedModelRequest.serialize, response_deserializer=gag_tuned_model.TunedModel.deserialize, ) - return self._stubs['update_tuned_model'] + return self._stubs["update_tuned_model"] @property - def delete_tuned_model(self) -> Callable[ - [model_service.DeleteTunedModelRequest], - Awaitable[empty_pb2.Empty]]: + def delete_tuned_model( + self, + ) -> Callable[[model_service.DeleteTunedModelRequest], Awaitable[empty_pb2.Empty]]: r"""Return a callable for the delete tuned model method over gRPC. Deletes a tuned model. @@ -431,18 +443,16 @@ def delete_tuned_model(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
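
On this asyncio transport every callable returns an awaitable, so the corresponding ModelServiceAsyncClient methods are coroutines and list methods return async pagers. A minimal usage sketch (assumes Application Default Credentials):

import asyncio

from google.ai import generativelanguage_v1beta3 as glm

async def main() -> None:
    client = glm.ModelServiceAsyncClient()
    pager = await client.list_tuned_models()
    async for tuned in pager:  # the pager fetches further pages lazily
        print(tuned.name)

asyncio.run(main())
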
- if 'delete_tuned_model' not in self._stubs: - self._stubs['delete_tuned_model'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.ModelService/DeleteTunedModel', + if "delete_tuned_model" not in self._stubs: + self._stubs["delete_tuned_model"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.ModelService/DeleteTunedModel", request_serializer=model_service.DeleteTunedModelRequest.serialize, response_deserializer=empty_pb2.Empty.FromString, ) - return self._stubs['delete_tuned_model'] + return self._stubs["delete_tuned_model"] def close(self): return self.grpc_channel.close() -__all__ = ( - 'ModelServiceGrpcAsyncIOTransport', -) +__all__ = ("ModelServiceGrpcAsyncIOTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/rest.py similarity index 69% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/rest.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/rest.py index 9d43aea96036..b5afef195a17 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/model_service/transports/rest.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/model_service/transports/rest.py @@ -14,25 +14,27 @@ # limitations under the License. # -from google.auth.transport.requests import AuthorizedSession # type: ignore +import dataclasses import json # type: ignore -import grpc # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth import credentials as ga_credentials # type: ignore +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import ( + gapic_v1, + operations_v1, + path_template, + rest_helpers, + rest_streaming, +) from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries -from google.api_core import rest_helpers -from google.api_core import rest_streaming -from google.api_core import path_template -from google.api_core import gapic_v1 - +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.transport.requests import AuthorizedSession # type: ignore from google.protobuf import json_format -from google.api_core import operations_v1 +import grpc # type: ignore from requests import __version__ as requests_version -import dataclasses -import re -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union -import warnings try: OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] @@ -40,15 +42,15 @@ OptionalRetry = Union[retries.Retry, object] # type: ignore -from google.ai.generativelanguage_v1beta3.types import model -from google.ai.generativelanguage_v1beta3.types import model_service -from google.ai.generativelanguage_v1beta3.types import tuned_model -from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model -from google.protobuf import empty_pb2 # type: ignore from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore 
-from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model +from google.ai.generativelanguage_v1beta3.types import model, model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .base import ModelServiceTransport DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, @@ -129,7 +131,12 @@ def post_update_tuned_model(self, response): """ - def pre_create_tuned_model(self, request: model_service.CreateTunedModelRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.CreateTunedModelRequest, Sequence[Tuple[str, str]]]: + + def pre_create_tuned_model( + self, + request: model_service.CreateTunedModelRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[model_service.CreateTunedModelRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for create_tuned_model Override in a subclass to manipulate the request or metadata @@ -137,7 +144,9 @@ def pre_create_tuned_model(self, request: model_service.CreateTunedModelRequest, """ return request, metadata - def post_create_tuned_model(self, response: operations_pb2.Operation) -> operations_pb2.Operation: + def post_create_tuned_model( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: """Post-rpc interceptor for create_tuned_model Override in a subclass to manipulate the response @@ -145,7 +154,12 @@ def post_create_tuned_model(self, response: operations_pb2.Operation) -> operati it is returned to user code. """ return response - def pre_delete_tuned_model(self, request: model_service.DeleteTunedModelRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.DeleteTunedModelRequest, Sequence[Tuple[str, str]]]: + + def pre_delete_tuned_model( + self, + request: model_service.DeleteTunedModelRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[model_service.DeleteTunedModelRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for delete_tuned_model Override in a subclass to manipulate the request or metadata @@ -153,7 +167,11 @@ def pre_delete_tuned_model(self, request: model_service.DeleteTunedModelRequest, """ return request, metadata - def pre_get_model(self, request: model_service.GetModelRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.GetModelRequest, Sequence[Tuple[str, str]]]: + def pre_get_model( + self, + request: model_service.GetModelRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[model_service.GetModelRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for get_model Override in a subclass to manipulate the request or metadata @@ -169,7 +187,12 @@ def post_get_model(self, response: model.Model) -> model.Model: it is returned to user code. 
""" return response - def pre_get_tuned_model(self, request: model_service.GetTunedModelRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.GetTunedModelRequest, Sequence[Tuple[str, str]]]: + + def pre_get_tuned_model( + self, + request: model_service.GetTunedModelRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[model_service.GetTunedModelRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for get_tuned_model Override in a subclass to manipulate the request or metadata @@ -177,7 +200,9 @@ def pre_get_tuned_model(self, request: model_service.GetTunedModelRequest, metad """ return request, metadata - def post_get_tuned_model(self, response: tuned_model.TunedModel) -> tuned_model.TunedModel: + def post_get_tuned_model( + self, response: tuned_model.TunedModel + ) -> tuned_model.TunedModel: """Post-rpc interceptor for get_tuned_model Override in a subclass to manipulate the response @@ -185,7 +210,12 @@ def post_get_tuned_model(self, response: tuned_model.TunedModel) -> tuned_model. it is returned to user code. """ return response - def pre_list_models(self, request: model_service.ListModelsRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.ListModelsRequest, Sequence[Tuple[str, str]]]: + + def pre_list_models( + self, + request: model_service.ListModelsRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[model_service.ListModelsRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for list_models Override in a subclass to manipulate the request or metadata @@ -193,7 +223,9 @@ def pre_list_models(self, request: model_service.ListModelsRequest, metadata: Se """ return request, metadata - def post_list_models(self, response: model_service.ListModelsResponse) -> model_service.ListModelsResponse: + def post_list_models( + self, response: model_service.ListModelsResponse + ) -> model_service.ListModelsResponse: """Post-rpc interceptor for list_models Override in a subclass to manipulate the response @@ -201,7 +233,12 @@ def post_list_models(self, response: model_service.ListModelsResponse) -> model_ it is returned to user code. """ return response - def pre_list_tuned_models(self, request: model_service.ListTunedModelsRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.ListTunedModelsRequest, Sequence[Tuple[str, str]]]: + + def pre_list_tuned_models( + self, + request: model_service.ListTunedModelsRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[model_service.ListTunedModelsRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for list_tuned_models Override in a subclass to manipulate the request or metadata @@ -209,7 +246,9 @@ def pre_list_tuned_models(self, request: model_service.ListTunedModelsRequest, m """ return request, metadata - def post_list_tuned_models(self, response: model_service.ListTunedModelsResponse) -> model_service.ListTunedModelsResponse: + def post_list_tuned_models( + self, response: model_service.ListTunedModelsResponse + ) -> model_service.ListTunedModelsResponse: """Post-rpc interceptor for list_tuned_models Override in a subclass to manipulate the response @@ -217,7 +256,12 @@ def post_list_tuned_models(self, response: model_service.ListTunedModelsResponse it is returned to user code. 
""" return response - def pre_update_tuned_model(self, request: model_service.UpdateTunedModelRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[model_service.UpdateTunedModelRequest, Sequence[Tuple[str, str]]]: + + def pre_update_tuned_model( + self, + request: model_service.UpdateTunedModelRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[model_service.UpdateTunedModelRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for update_tuned_model Override in a subclass to manipulate the request or metadata @@ -225,7 +269,9 @@ def pre_update_tuned_model(self, request: model_service.UpdateTunedModelRequest, """ return request, metadata - def post_update_tuned_model(self, response: gag_tuned_model.TunedModel) -> gag_tuned_model.TunedModel: + def post_update_tuned_model( + self, response: gag_tuned_model.TunedModel + ) -> gag_tuned_model.TunedModel: """Post-rpc interceptor for update_tuned_model Override in a subclass to manipulate the response @@ -256,20 +302,21 @@ class ModelServiceRestTransport(ModelServiceTransport): """ - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - client_cert_source_for_mtls: Optional[Callable[[ - ], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - url_scheme: str = 'https', - interceptor: Optional[ModelServiceRestInterceptor] = None, - api_audience: Optional[str] = None, - ) -> None: + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[ModelServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: """Instantiate the transport. Args: @@ -308,7 +355,9 @@ def __init__(self, *, # credentials object maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) if maybe_url_match is None: - raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER url_match_items = maybe_url_match.groupdict() @@ -319,10 +368,11 @@ def __init__(self, *, credentials=credentials, client_info=client_info, always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience + api_audience=api_audience, ) self._session = AuthorizedSession( - self._credentials, default_host=self.DEFAULT_HOST) + self._credentials, default_host=self.DEFAULT_HOST + ) self._operations_client: Optional[operations_v1.AbstractOperationsClient] = None if client_cert_source_for_mtls: self._session.configure_mtls_channel(client_cert_source_for_mtls) @@ -338,18 +388,20 @@ def operations_client(self) -> operations_v1.AbstractOperationsClient: """ # Only create a new client if we do not already have one. 
if self._operations_client is None: - http_options: Dict[str, List[Dict[str, str]]] = { - } + http_options: Dict[str, List[Dict[str, str]]] = {} rest_transport = operations_v1.OperationsRestTransport( - host=self._host, - # use the credentials which are saved - credentials=self._credentials, - scopes=self._scopes, - http_options=http_options, - path_prefix="v1beta3") + host=self._host, + # use the credentials which are saved + credentials=self._credentials, + scopes=self._scopes, + http_options=http_options, + path_prefix="v1beta3", + ) - self._operations_client = operations_v1.AbstractOperationsClient(transport=rest_transport) + self._operations_client = operations_v1.AbstractOperationsClient( + transport=rest_transport + ) # Return the client from cache. return self._operations_client @@ -358,19 +410,24 @@ class _CreateTunedModel(ModelServiceRestStub): def __hash__(self): return hash("CreateTunedModel") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: model_service.CreateTunedModelRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> operations_pb2.Operation: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: model_service.CreateTunedModelRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operations_pb2.Operation: r"""Call the create tuned model method over HTTP. 
Args: @@ -390,46 +447,51 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'post', - 'uri': '/v1beta3/tunedModels', - 'body': 'tuned_model', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1beta3/tunedModels", + "body": "tuned_model", + }, ] - request, metadata = self._interceptor.pre_create_tuned_model(request, metadata) + request, metadata = self._interceptor.pre_create_tuned_model( + request, metadata + ) pb_request = model_service.CreateTunedModelRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) # Jsonify the request body body = json_format.MessageToJson( - transcoded_request['body'], + transcoded_request["body"], including_default_value_fields=False, - use_integers_for_enums=True + use_integers_for_enums=True, ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), data=body, - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -446,19 +508,24 @@ class _DeleteTunedModel(ModelServiceRestStub): def __hash__(self): return hash("DeleteTunedModel") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: model_service.DeleteTunedModelRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: model_service.DeleteTunedModelRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ): r"""Call the delete tuned model method over HTTP. Args: @@ -471,37 +538,42 @@ def __call__(self, sent along with the request as metadata. 
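
Each of these REST stubs' __call__ methods follows the same pipeline visible in the hunks here: transcode the proto request against the declared http_options, JSON-encode the body and query params, then dispatch through the AuthorizedSession. A toy sketch of the transcoding step in isolation (the resource name is a placeholder):

from google.api_core import path_template

http_options = [{"method": "get", "uri": "/v1beta3/{name=models/*}"}]
transcoded = path_template.transcode(http_options, name="models/example")
print(transcoded["uri"])     # "/v1beta3/models/example"
print(transcoded["method"])  # "get"
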
""" - http_options: List[Dict[str, str]] = [{ - 'method': 'delete', - 'uri': '/v1beta3/{name=tunedModels/*}', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1beta3/{name=tunedModels/*}", + }, ] - request, metadata = self._interceptor.pre_delete_tuned_model(request, metadata) + request, metadata = self._interceptor.pre_delete_tuned_model( + request, metadata + ) pb_request = model_service.DeleteTunedModelRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -512,19 +584,24 @@ class _GetModel(ModelServiceRestStub): def __hash__(self): return hash("GetModel") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: model_service.GetModelRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> model.Model: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: model_service.GetModelRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: r"""Call the get model method over HTTP. 
Args: @@ -544,37 +621,40 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'get', - 'uri': '/v1beta3/{name=models/*}', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1beta3/{name=models/*}", + }, ] request, metadata = self._interceptor.pre_get_model(request, metadata) pb_request = model_service.GetModelRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -593,19 +673,24 @@ class _GetTunedModel(ModelServiceRestStub): def __hash__(self): return hash("GetTunedModel") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: model_service.GetTunedModelRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> tuned_model.TunedModel: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: model_service.GetTunedModelRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> tuned_model.TunedModel: r"""Call the get tuned model method over HTTP. 
Args: @@ -625,37 +710,40 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'get', - 'uri': '/v1beta3/{name=tunedModels/*}', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1beta3/{name=tunedModels/*}", + }, ] request, metadata = self._interceptor.pre_get_tuned_model(request, metadata) pb_request = model_service.GetTunedModelRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -674,12 +762,14 @@ class _ListModels(ModelServiceRestStub): def __hash__(self): return hash("ListModels") - def __call__(self, - request: model_service.ListModelsRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> model_service.ListModelsResponse: + def __call__( + self, + request: model_service.ListModelsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_service.ListModelsResponse: r"""Call the list models method over HTTP. 
Args: @@ -698,36 +788,39 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'get', - 'uri': '/v1beta3/models', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1beta3/models", + }, ] request, metadata = self._interceptor.pre_list_models(request, metadata) pb_request = model_service.ListModelsRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -746,12 +839,14 @@ class _ListTunedModels(ModelServiceRestStub): def __hash__(self): return hash("ListTunedModels") - def __call__(self, - request: model_service.ListTunedModelsRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> model_service.ListTunedModelsResponse: + def __call__( + self, + request: model_service.ListTunedModelsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_service.ListTunedModelsResponse: r"""Call the list tuned models method over HTTP. 
Args: @@ -770,36 +865,41 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'get', - 'uri': '/v1beta3/tunedModels', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1beta3/tunedModels", + }, ] - request, metadata = self._interceptor.pre_list_tuned_models(request, metadata) + request, metadata = self._interceptor.pre_list_tuned_models( + request, metadata + ) pb_request = model_service.ListTunedModelsRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -818,19 +918,26 @@ class _UpdateTunedModel(ModelServiceRestStub): def __hash__(self): return hash("UpdateTunedModel") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - "updateMask" : {}, } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + "updateMask": {}, + } @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: model_service.UpdateTunedModelRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> gag_tuned_model.TunedModel: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: model_service.UpdateTunedModelRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_tuned_model.TunedModel: r"""Call the update tuned model method over HTTP. 
Args: @@ -849,46 +956,51 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'patch', - 'uri': '/v1beta3/{tuned_model.name=tunedModels/*}', - 'body': 'tuned_model', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1beta3/{tuned_model.name=tunedModels/*}", + "body": "tuned_model", + }, ] - request, metadata = self._interceptor.pre_update_tuned_model(request, metadata) + request, metadata = self._interceptor.pre_update_tuned_model( + request, metadata + ) pb_request = model_service.UpdateTunedModelRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) # Jsonify the request body body = json_format.MessageToJson( - transcoded_request['body'], + transcoded_request["body"], including_default_value_fields=False, - use_integers_for_enums=True + use_integers_for_enums=True, ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), data=body, - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -904,60 +1016,60 @@ def __call__(self, return resp @property - def create_tuned_model(self) -> Callable[ - [model_service.CreateTunedModelRequest], - operations_pb2.Operation]: + def create_tuned_model( + self, + ) -> Callable[[model_service.CreateTunedModelRequest], operations_pb2.Operation]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._CreateTunedModel(self._session, self._host, self._interceptor) # type: ignore + return self._CreateTunedModel(self._session, self._host, self._interceptor) # type: ignore @property - def delete_tuned_model(self) -> Callable[ - [model_service.DeleteTunedModelRequest], - empty_pb2.Empty]: + def delete_tuned_model( + self, + ) -> Callable[[model_service.DeleteTunedModelRequest], empty_pb2.Empty]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._DeleteTunedModel(self._session, self._host, self._interceptor) # type: ignore + return self._DeleteTunedModel(self._session, self._host, self._interceptor) # type: ignore @property - def get_model(self) -> Callable[ - [model_service.GetModelRequest], - model.Model]: + def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
# In C++ this would require a dynamic_cast - return self._GetModel(self._session, self._host, self._interceptor) # type: ignore + return self._GetModel(self._session, self._host, self._interceptor) # type: ignore @property - def get_tuned_model(self) -> Callable[ - [model_service.GetTunedModelRequest], - tuned_model.TunedModel]: + def get_tuned_model( + self, + ) -> Callable[[model_service.GetTunedModelRequest], tuned_model.TunedModel]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._GetTunedModel(self._session, self._host, self._interceptor) # type: ignore + return self._GetTunedModel(self._session, self._host, self._interceptor) # type: ignore @property - def list_models(self) -> Callable[ - [model_service.ListModelsRequest], - model_service.ListModelsResponse]: + def list_models( + self, + ) -> Callable[[model_service.ListModelsRequest], model_service.ListModelsResponse]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._ListModels(self._session, self._host, self._interceptor) # type: ignore + return self._ListModels(self._session, self._host, self._interceptor) # type: ignore @property - def list_tuned_models(self) -> Callable[ - [model_service.ListTunedModelsRequest], - model_service.ListTunedModelsResponse]: + def list_tuned_models( + self, + ) -> Callable[ + [model_service.ListTunedModelsRequest], model_service.ListTunedModelsResponse + ]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._ListTunedModels(self._session, self._host, self._interceptor) # type: ignore + return self._ListTunedModels(self._session, self._host, self._interceptor) # type: ignore @property - def update_tuned_model(self) -> Callable[ - [model_service.UpdateTunedModelRequest], - gag_tuned_model.TunedModel]: + def update_tuned_model( + self, + ) -> Callable[[model_service.UpdateTunedModelRequest], gag_tuned_model.TunedModel]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
# In C++ this would require a dynamic_cast - return self._UpdateTunedModel(self._session, self._host, self._interceptor) # type: ignore + return self._UpdateTunedModel(self._session, self._host, self._interceptor) # type: ignore @property def kind(self) -> str: @@ -967,6 +1079,4 @@ def close(self): self._session.close() -__all__=( - 'ModelServiceRestTransport', -) +__all__ = ("ModelServiceRestTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/__init__.py similarity index 91% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/__init__.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/__init__.py index eb61a596594a..7cd02e1fc232 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/__init__.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/__init__.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from .client import PermissionServiceClient from .async_client import PermissionServiceAsyncClient +from .client import PermissionServiceClient __all__ = ( - 'PermissionServiceClient', - 'PermissionServiceAsyncClient', + "PermissionServiceClient", + "PermissionServiceAsyncClient", ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py similarity index 84% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py index f83af219c903..9b9faceba44c 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py @@ -16,31 +16,43 @@ from collections import OrderedDict import functools import re -from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union - -from google.ai.generativelanguage_v1beta3 import gapic_version as package_version +from typing import ( + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) -from google.api_core.client_options import ClientOptions from google.api_core import exceptions as core_exceptions from google.api_core import gapic_v1 from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.oauth2 import service_account # type: ignore +from google.api_core.client_options import ClientOptions +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version try: OptionalRetry 
= Union[retries.Retry, gapic_v1.method._MethodDefault] except AttributeError: # pragma: NO COVER OptionalRetry = Union[retries.Retry, object] # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore + from google.ai.generativelanguage_v1beta3.services.permission_service import pagers -from google.ai.generativelanguage_v1beta3.types import permission from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.ai.generativelanguage_v1beta3.types import permission from google.ai.generativelanguage_v1beta3.types import permission_service -from google.longrunning import operations_pb2 # type: ignore -from google.protobuf import field_mask_pb2 # type: ignore -from .transports.base import PermissionServiceTransport, DEFAULT_CLIENT_INFO -from .transports.grpc_asyncio import PermissionServiceGrpcAsyncIOTransport + from .client import PermissionServiceClient +from .transports.base import DEFAULT_CLIENT_INFO, PermissionServiceTransport +from .transports.grpc_asyncio import PermissionServiceGrpcAsyncIOTransport class PermissionServiceAsyncClient: @@ -56,17 +68,33 @@ class PermissionServiceAsyncClient: permission_path = staticmethod(PermissionServiceClient.permission_path) parse_permission_path = staticmethod(PermissionServiceClient.parse_permission_path) tuned_model_path = staticmethod(PermissionServiceClient.tuned_model_path) - parse_tuned_model_path = staticmethod(PermissionServiceClient.parse_tuned_model_path) - common_billing_account_path = staticmethod(PermissionServiceClient.common_billing_account_path) - parse_common_billing_account_path = staticmethod(PermissionServiceClient.parse_common_billing_account_path) + parse_tuned_model_path = staticmethod( + PermissionServiceClient.parse_tuned_model_path + ) + common_billing_account_path = staticmethod( + PermissionServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + PermissionServiceClient.parse_common_billing_account_path + ) common_folder_path = staticmethod(PermissionServiceClient.common_folder_path) - parse_common_folder_path = staticmethod(PermissionServiceClient.parse_common_folder_path) - common_organization_path = staticmethod(PermissionServiceClient.common_organization_path) - parse_common_organization_path = staticmethod(PermissionServiceClient.parse_common_organization_path) + parse_common_folder_path = staticmethod( + PermissionServiceClient.parse_common_folder_path + ) + common_organization_path = staticmethod( + PermissionServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + PermissionServiceClient.parse_common_organization_path + ) common_project_path = staticmethod(PermissionServiceClient.common_project_path) - parse_common_project_path = staticmethod(PermissionServiceClient.parse_common_project_path) + parse_common_project_path = staticmethod( + PermissionServiceClient.parse_common_project_path + ) common_location_path = staticmethod(PermissionServiceClient.common_location_path) - parse_common_location_path = staticmethod(PermissionServiceClient.parse_common_location_path) + parse_common_location_path = staticmethod( + PermissionServiceClient.parse_common_location_path + ) @classmethod def from_service_account_info(cls, info: dict, *args, **kwargs): @@ -102,7 +130,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): from_service_account_json = from_service_account_file @classmethod - def 
get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None): + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): """Return the API endpoint and client cert source for mutual TLS. The client cert source is determined in the following order: @@ -144,14 +174,18 @@ def transport(self) -> PermissionServiceTransport: """ return self._client.transport - get_transport_class = functools.partial(type(PermissionServiceClient).get_transport_class, type(PermissionServiceClient)) - - def __init__(self, *, - credentials: Optional[ga_credentials.Credentials] = None, - transport: Union[str, PermissionServiceTransport] = "grpc_asyncio", - client_options: Optional[ClientOptions] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + get_transport_class = functools.partial( + type(PermissionServiceClient).get_transport_class, type(PermissionServiceClient) + ) + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Union[str, PermissionServiceTransport] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiates the permission service client. Args: @@ -189,18 +223,20 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def create_permission(self, - request: Optional[Union[permission_service.CreatePermissionRequest, dict]] = None, - *, - parent: Optional[str] = None, - permission: Optional[gag_permission.Permission] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gag_permission.Permission: + async def create_permission( + self, + request: Optional[ + Union[permission_service.CreatePermissionRequest, dict] + ] = None, + *, + parent: Optional[str] = None, + permission: Optional[gag_permission.Permission] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_permission.Permission: r"""Create a permission to a specific resource. .. code-block:: python @@ -282,8 +318,10 @@ async def sample_create_permission(): # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, permission]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = permission_service.CreatePermissionRequest(request) @@ -305,9 +343,7 @@ async def sample_create_permission(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("parent", request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. @@ -321,14 +357,15 @@ async def sample_create_permission(): # Done; return the response. 
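For orientation, a minimal sketch of calling the new create_permission method from user code. The tuned-model name, grantee address, and role are hypothetical placeholders, and application-default credentials are assumed; none of these values come from the patch itself.

import asyncio

from google.ai import generativelanguage_v1beta3 as glm


async def grant_reader():
    client = glm.PermissionServiceAsyncClient()
    # Permission is the v1beta3 message added by this patch; the
    # concrete field values below are illustrative only.
    permission = glm.Permission(
        grantee_type=glm.Permission.GranteeType.USER,
        email_address="reader@example.com",
        role=glm.Permission.Role.READER,
    )
    created = await client.create_permission(
        parent="tunedModels/my-model",  # hypothetical tuned model
        permission=permission,
    )
    print(created.name)


asyncio.run(grant_reader())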
return response - async def get_permission(self, - request: Optional[Union[permission_service.GetPermissionRequest, dict]] = None, - *, - name: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> permission.Permission: + async def get_permission( + self, + request: Optional[Union[permission_service.GetPermissionRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> permission.Permission: r"""Gets information about a specific Permission. .. code-block:: python @@ -408,8 +445,10 @@ async def sample_get_permission(): # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = permission_service.GetPermissionRequest(request) @@ -429,9 +468,7 @@ async def sample_get_permission(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -445,14 +482,17 @@ async def sample_get_permission(): # Done; return the response. return response - async def list_permissions(self, - request: Optional[Union[permission_service.ListPermissionsRequest, dict]] = None, - *, - parent: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListPermissionsAsyncPager: + async def list_permissions( + self, + request: Optional[ + Union[permission_service.ListPermissionsRequest, dict] + ] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListPermissionsAsyncPager: r"""Lists permissions for the specific resource. .. code-block:: python @@ -512,8 +552,10 @@ async def sample_list_permissions(): # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = permission_service.ListPermissionsRequest(request) @@ -533,9 +575,7 @@ async def sample_list_permissions(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("parent", request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. @@ -558,15 +598,18 @@ async def sample_list_permissions(): # Done; return the response. 
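As a sketch of the pager surface (same hypothetical client and parent as above), iterating a ListPermissionsAsyncPager fetches follow-up pages transparently via next_page_token:

async def show_permissions(client: "glm.PermissionServiceAsyncClient") -> None:
    pager = await client.list_permissions(parent="tunedModels/my-model")
    async for permission in pager:  # the pager requests further pages on demand
        print(permission.name, permission.role)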
return response - async def update_permission(self, - request: Optional[Union[permission_service.UpdatePermissionRequest, dict]] = None, - *, - permission: Optional[gag_permission.Permission] = None, - update_mask: Optional[field_mask_pb2.FieldMask] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gag_permission.Permission: + async def update_permission( + self, + request: Optional[ + Union[permission_service.UpdatePermissionRequest, dict] + ] = None, + *, + permission: Optional[gag_permission.Permission] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_permission.Permission: r"""Updates the permission. .. code-block:: python @@ -652,8 +695,10 @@ async def sample_update_permission(): # gotten any keyword arguments that map to the request. has_flattened_params = any([permission, update_mask]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = permission_service.UpdatePermissionRequest(request) @@ -675,9 +720,9 @@ async def sample_update_permission(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("permission.name", request.permission.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("permission.name", request.permission.name),) + ), ) # Send the request. @@ -691,14 +736,17 @@ async def sample_update_permission(): # Done; return the response. return response - async def delete_permission(self, - request: Optional[Union[permission_service.DeletePermissionRequest, dict]] = None, - *, - name: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + async def delete_permission( + self, + request: Optional[ + Union[permission_service.DeletePermissionRequest, dict] + ] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Deletes the permission. .. code-block:: python @@ -745,8 +793,10 @@ async def sample_delete_permission(): # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = permission_service.DeletePermissionRequest(request) @@ -766,9 +816,7 @@ async def sample_delete_permission(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. 
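A hedged sketch of update_permission: the FieldMask names the fields to change, and "role" is used here purely as an illustrative mask path.

from google.protobuf import field_mask_pb2


async def promote_to_writer(client, name: str) -> "glm.Permission":
    # Only the fields listed in update_mask are modified on the server.
    permission = glm.Permission(name=name, role=glm.Permission.Role.WRITER)
    return await client.update_permission(
        permission=permission,
        update_mask=field_mask_pb2.FieldMask(paths=["role"]),
    )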
@@ -779,13 +827,16 @@ async def sample_delete_permission(): metadata=metadata, ) - async def transfer_ownership(self, - request: Optional[Union[permission_service.TransferOwnershipRequest, dict]] = None, - *, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> permission_service.TransferOwnershipResponse: + async def transfer_ownership( + self, + request: Optional[ + Union[permission_service.TransferOwnershipRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> permission_service.TransferOwnershipResponse: r"""Transfers ownership of the tuned model. This is the only way to change ownership of the tuned model. The current owner will be downgraded to writer @@ -846,9 +897,7 @@ async def sample_transfer_ownership(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -868,9 +917,10 @@ async def __aenter__(self) -> "PermissionServiceAsyncClient": async def __aexit__(self, exc_type, exc, tb): await self.transport.close() -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - -__all__ = ( - "PermissionServiceAsyncClient", +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ ) + + +__all__ = ("PermissionServiceAsyncClient",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/client.py similarity index 82% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/client.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/client.py index e25e91f72b5f..78bbe681b0cc 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/client.py @@ -16,32 +16,45 @@ from collections import OrderedDict import os import re -from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast - -from google.ai.generativelanguage_v1beta3 import gapic_version as package_version +from typing import ( + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) from google.api_core import client_options as client_options_lib from google.api_core import exceptions as core_exceptions from google.api_core import gapic_v1 from google.api_core import retry as retries -from google.auth import credentials as ga_credentials # type: ignore -from google.auth.transport import mtls # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth.exceptions import MutualTLSChannelError # type: ignore -from google.oauth2 import service_account # type: ignore +from google.auth import credentials as ga_credentials # type: ignore 
+from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version try: OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] except AttributeError: # pragma: NO COVER OptionalRetry = Union[retries.Retry, object] # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore + from google.ai.generativelanguage_v1beta3.services.permission_service import pagers -from google.ai.generativelanguage_v1beta3.types import permission from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.ai.generativelanguage_v1beta3.types import permission from google.ai.generativelanguage_v1beta3.types import permission_service -from google.longrunning import operations_pb2 # type: ignore -from google.protobuf import field_mask_pb2 # type: ignore -from .transports.base import PermissionServiceTransport, DEFAULT_CLIENT_INFO + +from .transports.base import DEFAULT_CLIENT_INFO, PermissionServiceTransport from .transports.grpc import PermissionServiceGrpcTransport from .transports.grpc_asyncio import PermissionServiceGrpcAsyncIOTransport from .transports.rest import PermissionServiceRestTransport @@ -54,14 +67,18 @@ class PermissionServiceClientMeta(type): support objects (e.g. transport) without polluting the client instance objects. """ - _transport_registry = OrderedDict() # type: Dict[str, Type[PermissionServiceTransport]] + + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[PermissionServiceTransport]] _transport_registry["grpc"] = PermissionServiceGrpcTransport _transport_registry["grpc_asyncio"] = PermissionServiceGrpcAsyncIOTransport _transport_registry["rest"] = PermissionServiceRestTransport - def get_transport_class(cls, - label: Optional[str] = None, - ) -> Type[PermissionServiceTransport]: + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[PermissionServiceTransport]: """Returns an appropriate transport class. Args: @@ -151,8 +168,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): Returns: PermissionServiceClient: The constructed client. 
""" - credentials = service_account.Credentials.from_service_account_file( - filename) + credentials = service_account.Credentials.from_service_account_file(filename) kwargs["credentials"] = credentials return cls(*args, **kwargs) @@ -169,84 +185,120 @@ def transport(self) -> PermissionServiceTransport: return self._transport @staticmethod - def permission_path(tuned_model: str,permission: str,) -> str: + def permission_path( + tuned_model: str, + permission: str, + ) -> str: """Returns a fully-qualified permission string.""" - return "tunedModels/{tuned_model}/permissions/{permission}".format(tuned_model=tuned_model, permission=permission, ) + return "tunedModels/{tuned_model}/permissions/{permission}".format( + tuned_model=tuned_model, + permission=permission, + ) @staticmethod - def parse_permission_path(path: str) -> Dict[str,str]: + def parse_permission_path(path: str) -> Dict[str, str]: """Parses a permission path into its component segments.""" - m = re.match(r"^tunedModels/(?P.+?)/permissions/(?P.+?)$", path) + m = re.match( + r"^tunedModels/(?P.+?)/permissions/(?P.+?)$", path + ) return m.groupdict() if m else {} @staticmethod - def tuned_model_path(tuned_model: str,) -> str: + def tuned_model_path( + tuned_model: str, + ) -> str: """Returns a fully-qualified tuned_model string.""" - return "tunedModels/{tuned_model}".format(tuned_model=tuned_model, ) + return "tunedModels/{tuned_model}".format( + tuned_model=tuned_model, + ) @staticmethod - def parse_tuned_model_path(path: str) -> Dict[str,str]: + def parse_tuned_model_path(path: str) -> Dict[str, str]: """Parses a tuned_model path into its component segments.""" m = re.match(r"^tunedModels/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_billing_account_path(billing_account: str, ) -> str: + def common_billing_account_path( + billing_account: str, + ) -> str: """Returns a fully-qualified billing_account string.""" - return "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) @staticmethod - def parse_common_billing_account_path(path: str) -> Dict[str,str]: + def parse_common_billing_account_path(path: str) -> Dict[str, str]: """Parse a billing_account path into its component segments.""" m = re.match(r"^billingAccounts/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_folder_path(folder: str, ) -> str: + def common_folder_path( + folder: str, + ) -> str: """Returns a fully-qualified folder string.""" - return "folders/{folder}".format(folder=folder, ) + return "folders/{folder}".format( + folder=folder, + ) @staticmethod - def parse_common_folder_path(path: str) -> Dict[str,str]: + def parse_common_folder_path(path: str) -> Dict[str, str]: """Parse a folder path into its component segments.""" m = re.match(r"^folders/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_organization_path(organization: str, ) -> str: + def common_organization_path( + organization: str, + ) -> str: """Returns a fully-qualified organization string.""" - return "organizations/{organization}".format(organization=organization, ) + return "organizations/{organization}".format( + organization=organization, + ) @staticmethod - def parse_common_organization_path(path: str) -> Dict[str,str]: + def parse_common_organization_path(path: str) -> Dict[str, str]: """Parse a organization path into its component segments.""" m = re.match(r"^organizations/(?P.+?)$", 
path) return m.groupdict() if m else {} @staticmethod - def common_project_path(project: str, ) -> str: + def common_project_path( + project: str, + ) -> str: """Returns a fully-qualified project string.""" - return "projects/{project}".format(project=project, ) + return "projects/{project}".format( + project=project, + ) @staticmethod - def parse_common_project_path(path: str) -> Dict[str,str]: + def parse_common_project_path(path: str) -> Dict[str, str]: """Parse a project path into its component segments.""" m = re.match(r"^projects/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod - def common_location_path(project: str, location: str, ) -> str: + def common_location_path( + project: str, + location: str, + ) -> str: """Returns a fully-qualified location string.""" - return "projects/{project}/locations/{location}".format(project=project, location=location, ) + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) @staticmethod - def parse_common_location_path(path: str) -> Dict[str,str]: + def parse_common_location_path(path: str) -> Dict[str, str]: """Parse a location path into its component segments.""" m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) return m.groupdict() if m else {} @classmethod - def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None): + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): """Return the API endpoint and client cert source for mutual TLS. The client cert source is determined in the following order: @@ -282,9 +334,13 @@ def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_optio use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") if use_client_cert not in ("true", "false"): - raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`") + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) if use_mtls_endpoint not in ("auto", "never", "always"): - raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`") + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) # Figure out the client cert source to use. client_cert_source = None @@ -297,19 +353,23 @@ def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_optio # Figure out which api endpoint to use. 
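A short sketch of how the environment variables validated above steer endpoint selection; the printed values are the library defaults, not something this patch changes:

import os

from google.ai import generativelanguage_v1beta3 as glm

os.environ["GOOGLE_API_USE_MTLS_ENDPOINT"] = "never"  # force the plain endpoint
endpoint, cert_source = glm.PermissionServiceClient.get_mtls_endpoint_and_cert_source()
print(endpoint)      # generativelanguage.googleapis.com
print(cert_source)   # None unless a client certificate is configured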
if client_options.api_endpoint is not None: api_endpoint = client_options.api_endpoint - elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source): + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): api_endpoint = cls.DEFAULT_MTLS_ENDPOINT else: api_endpoint = cls.DEFAULT_ENDPOINT return api_endpoint, client_cert_source - def __init__(self, *, - credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, PermissionServiceTransport]] = None, - client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[Union[str, PermissionServiceTransport]] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiates the permission service client. Args: @@ -353,11 +413,15 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() client_options = cast(client_options_lib.ClientOptions, client_options) - api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options + ) api_key_value = getattr(client_options, "api_key", None) if api_key_value and credentials: - raise ValueError("client_options.api_key and credentials are mutually exclusive") + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport @@ -365,8 +429,10 @@ def __init__(self, *, if isinstance(transport, PermissionServiceTransport): # transport is a PermissionServiceTransport instance. if credentials or client_options.credentials_file or api_key_value: - raise ValueError("When providing a transport instance, " - "provide its credentials directly.") + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." 
+ ) if client_options.scopes: raise ValueError( "When providing a transport instance, provide its scopes " @@ -376,8 +442,12 @@ def __init__(self, *, else: import google.auth._default # type: ignore - if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): - credentials = google.auth._default.get_api_key_credentials(api_key_value) + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) Transport = type(self).get_transport_class(transport) self._transport = Transport( @@ -392,15 +462,18 @@ def __init__(self, *, api_audience=client_options.api_audience, ) - def create_permission(self, - request: Optional[Union[permission_service.CreatePermissionRequest, dict]] = None, - *, - parent: Optional[str] = None, - permission: Optional[gag_permission.Permission] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gag_permission.Permission: + def create_permission( + self, + request: Optional[ + Union[permission_service.CreatePermissionRequest, dict] + ] = None, + *, + parent: Optional[str] = None, + permission: Optional[gag_permission.Permission] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_permission.Permission: r"""Create a permission to a specific resource. .. code-block:: python @@ -482,8 +555,10 @@ def sample_create_permission(): # gotten any keyword arguments that map to the request. has_flattened_params = any([parent, permission]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a permission_service.CreatePermissionRequest. @@ -502,12 +577,10 @@ def sample_create_permission(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.create_permission] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("parent", request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. @@ -521,14 +594,15 @@ def sample_create_permission(): # Done; return the response. return response - def get_permission(self, - request: Optional[Union[permission_service.GetPermissionRequest, dict]] = None, - *, - name: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> permission.Permission: + def get_permission( + self, + request: Optional[Union[permission_service.GetPermissionRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> permission.Permission: r"""Gets information about a specific Permission. .. 
code-block:: python @@ -608,8 +682,10 @@ def sample_get_permission(): # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a permission_service.GetPermissionRequest. @@ -626,12 +702,10 @@ def sample_get_permission(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.get_permission] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -645,14 +719,17 @@ def sample_get_permission(): # Done; return the response. return response - def list_permissions(self, - request: Optional[Union[permission_service.ListPermissionsRequest, dict]] = None, - *, - parent: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> pagers.ListPermissionsPager: + def list_permissions( + self, + request: Optional[ + Union[permission_service.ListPermissionsRequest, dict] + ] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListPermissionsPager: r"""Lists permissions for the specific resource. .. code-block:: python @@ -712,8 +789,10 @@ def sample_list_permissions(): # gotten any keyword arguments that map to the request. has_flattened_params = any([parent]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a permission_service.ListPermissionsRequest. @@ -730,12 +809,10 @@ def sample_list_permissions(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.list_permissions] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("parent", request.parent), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), ) # Send the request. @@ -758,15 +835,18 @@ def sample_list_permissions(): # Done; return the response. 
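For reference, a sketch combining the permission_path helper with the synchronous client; the IDs are placeholders and application-default credentials are assumed:

from google.ai import generativelanguage_v1beta3 as glm

client = glm.PermissionServiceClient()
# permission_path is a pure string formatter; parse_permission_path inverts it.
name = glm.PermissionServiceClient.permission_path(
    tuned_model="my-model", permission="perm-123"  # hypothetical IDs
)
perm = client.get_permission(name=name)
print(perm.role)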
return response - def update_permission(self, - request: Optional[Union[permission_service.UpdatePermissionRequest, dict]] = None, - *, - permission: Optional[gag_permission.Permission] = None, - update_mask: Optional[field_mask_pb2.FieldMask] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> gag_permission.Permission: + def update_permission( + self, + request: Optional[ + Union[permission_service.UpdatePermissionRequest, dict] + ] = None, + *, + permission: Optional[gag_permission.Permission] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_permission.Permission: r"""Updates the permission. .. code-block:: python @@ -852,8 +932,10 @@ def sample_update_permission(): # gotten any keyword arguments that map to the request. has_flattened_params = any([permission, update_mask]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a permission_service.UpdatePermissionRequest. @@ -872,12 +954,12 @@ def sample_update_permission(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.update_permission] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("permission.name", request.permission.name), - )), + gapic_v1.routing_header.to_grpc_metadata( + (("permission.name", request.permission.name),) + ), ) # Send the request. @@ -891,14 +973,17 @@ def sample_update_permission(): # Done; return the response. return response - def delete_permission(self, - request: Optional[Union[permission_service.DeletePermissionRequest, dict]] = None, - *, - name: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> None: + def delete_permission( + self, + request: Optional[ + Union[permission_service.DeletePermissionRequest, dict] + ] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: r"""Deletes the permission. .. code-block:: python @@ -945,8 +1030,10 @@ def sample_delete_permission(): # gotten any keyword arguments that map to the request. has_flattened_params = any([name]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a permission_service.DeletePermissionRequest. @@ -963,12 +1050,10 @@ def sample_delete_permission(): # and friendly error handling. 
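Deletion, sketched under the same assumptions, returns None on success:

client.delete_permission(
    name="tunedModels/my-model/permissions/perm-123"  # hypothetical resource
)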
rpc = self._transport._wrapped_methods[self._transport.delete_permission] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -979,13 +1064,16 @@ def sample_delete_permission(): metadata=metadata, ) - def transfer_ownership(self, - request: Optional[Union[permission_service.TransferOwnershipRequest, dict]] = None, - *, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> permission_service.TransferOwnershipResponse: + def transfer_ownership( + self, + request: Optional[ + Union[permission_service.TransferOwnershipRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> permission_service.TransferOwnershipResponse: r"""Transfers ownership of the tuned model. This is the only way to change ownership of the tuned model. The current owner will be downgraded to writer @@ -1044,12 +1132,10 @@ def sample_transfer_ownership(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.transfer_ownership] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("name", request.name), - )), + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), ) # Send the request. @@ -1077,18 +1163,9 @@ def __exit__(self, type, value, traceback): self.transport.close() +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) - - - - - - - -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) - - -__all__ = ( - "PermissionServiceClient", -) +__all__ = ("PermissionServiceClient",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/pagers.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/pagers.py similarity index 84% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/pagers.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/pagers.py index 188d4c7963a6..ae6b80975a9a 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/pagers.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/pagers.py @@ -13,10 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from typing import Any, AsyncIterator, Awaitable, Callable, Sequence, Tuple, Optional, Iterator +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Iterator, + Optional, + Sequence, + Tuple, +) -from google.ai.generativelanguage_v1beta3.types import permission -from google.ai.generativelanguage_v1beta3.types import permission_service +from google.ai.generativelanguage_v1beta3.types import permission, permission_service class ListPermissionsPager: @@ -36,12 +44,15 @@ class ListPermissionsPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., permission_service.ListPermissionsResponse], - request: permission_service.ListPermissionsRequest, - response: permission_service.ListPermissionsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., permission_service.ListPermissionsResponse], + request: permission_service.ListPermissionsRequest, + response: permission_service.ListPermissionsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiate the pager. Args: @@ -75,7 +86,7 @@ def __iter__(self) -> Iterator[permission.Permission]: yield from page.permissions def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) class ListPermissionsAsyncPager: @@ -95,12 +106,15 @@ class ListPermissionsAsyncPager: attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ - def __init__(self, - method: Callable[..., Awaitable[permission_service.ListPermissionsResponse]], - request: permission_service.ListPermissionsRequest, - response: permission_service.ListPermissionsResponse, - *, - metadata: Sequence[Tuple[str, str]] = ()): + + def __init__( + self, + method: Callable[..., Awaitable[permission_service.ListPermissionsResponse]], + request: permission_service.ListPermissionsRequest, + response: permission_service.ListPermissionsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): """Instantiates the pager. 
Args: @@ -128,6 +142,7 @@ async def pages(self) -> AsyncIterator[permission_service.ListPermissionsRespons self._request.page_token = self._response.next_page_token self._response = await self._method(self._request, metadata=self._metadata) yield self._response + def __aiter__(self) -> AsyncIterator[permission.Permission]: async def async_generator(): async for page in self.pages: @@ -137,4 +152,4 @@ async def async_generator(): return async_generator() def __repr__(self) -> str: - return '{0}<{1!r}>'.format(self.__class__.__name__, self._response) + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/__init__.py similarity index 66% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/__init__.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/__init__.py index 5232d4043c80..fe33568492a6 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/__init__.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/__init__.py @@ -19,20 +19,18 @@ from .base import PermissionServiceTransport from .grpc import PermissionServiceGrpcTransport from .grpc_asyncio import PermissionServiceGrpcAsyncIOTransport -from .rest import PermissionServiceRestTransport -from .rest import PermissionServiceRestInterceptor - +from .rest import PermissionServiceRestInterceptor, PermissionServiceRestTransport # Compile a registry of transports. 
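The registry compiled below maps the transport labels accepted by the client constructor; as a sketch, passing "rest" selects the HTTP/JSON transport from this patch (application-default credentials assumed, as above):

from google.ai import generativelanguage_v1beta3 as glm

client = glm.PermissionServiceClient(transport="rest")
print(type(client.transport).__name__)  # PermissionServiceRestTransport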
_transport_registry = OrderedDict() # type: Dict[str, Type[PermissionServiceTransport]] -_transport_registry['grpc'] = PermissionServiceGrpcTransport -_transport_registry['grpc_asyncio'] = PermissionServiceGrpcAsyncIOTransport -_transport_registry['rest'] = PermissionServiceRestTransport +_transport_registry["grpc"] = PermissionServiceGrpcTransport +_transport_registry["grpc_asyncio"] = PermissionServiceGrpcAsyncIOTransport +_transport_registry["rest"] = PermissionServiceRestTransport __all__ = ( - 'PermissionServiceTransport', - 'PermissionServiceGrpcTransport', - 'PermissionServiceGrpcAsyncIOTransport', - 'PermissionServiceRestTransport', - 'PermissionServiceRestInterceptor', + "PermissionServiceTransport", + "PermissionServiceGrpcTransport", + "PermissionServiceGrpcAsyncIOTransport", + "PermissionServiceRestTransport", + "PermissionServiceRestInterceptor", ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/base.py similarity index 63% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/base.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/base.py index d0d736d33e11..5c530a8a70b1 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/base.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/base.py @@ -16,44 +16,46 @@ import abc from typing import Awaitable, Callable, Dict, Optional, Sequence, Union -from google.ai.generativelanguage_v1beta3 import gapic_version as package_version - -import google.auth # type: ignore import google.api_core from google.api_core import exceptions as core_exceptions from google.api_core import gapic_v1 from google.api_core import retry as retries +import google.auth # type: ignore from google.auth import credentials as ga_credentials # type: ignore -from google.oauth2 import service_account # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account # type: ignore +from google.protobuf import empty_pb2 # type: ignore -from google.ai.generativelanguage_v1beta3.types import permission +from google.ai.generativelanguage_v1beta3 import gapic_version as package_version from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.ai.generativelanguage_v1beta3.types import permission from google.ai.generativelanguage_v1beta3.types import permission_service -from google.longrunning import operations_pb2 # type: ignore -from google.protobuf import empty_pb2 # type: ignore -DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__) +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) class PermissionServiceTransport(abc.ABC): """Abstract transport class for PermissionService.""" - AUTH_SCOPES = ( - ) + AUTH_SCOPES = () + + DEFAULT_HOST: str = "generativelanguage.googleapis.com" - DEFAULT_HOST: str = 'generativelanguage.googleapis.com' def __init__( - self, *, - host: str = DEFAULT_HOST, - credentials: Optional[ga_credentials.Credentials] = None, - 
credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - **kwargs, - ) -> None: + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: """Instantiate the transport. Args: @@ -87,30 +89,38 @@ def __init__( # If no credentials are provided, then determine the appropriate # defaults. if credentials and credentials_file: - raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive") + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) if credentials_file is not None: credentials, _ = google.auth.load_credentials_from_file( - credentials_file, - **scopes_kwargs, - quota_project_id=quota_project_id - ) + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) elif credentials is None: - credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id) + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) # Don't apply audience if the credentials file passed from user. if hasattr(credentials, "with_gdch_audience"): - credentials = credentials.with_gdch_audience(api_audience if api_audience else host) + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) # If the credentials are service account credentials, then always try to use self signed JWT. - if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"): + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): credentials = credentials.with_always_use_jwt_access(True) # Save the credentials. self._credentials = credentials # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host def _prep_wrapped_messages(self, client_info): @@ -146,69 +156,75 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } + } def close(self): """Closes resources associated with the transport. - .. warning:: - Only call this method if the transport is NOT shared - with other clients - this may cause errors in other clients! + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! 
""" raise NotImplementedError() @property - def create_permission(self) -> Callable[ - [permission_service.CreatePermissionRequest], - Union[ - gag_permission.Permission, - Awaitable[gag_permission.Permission] - ]]: + def create_permission( + self, + ) -> Callable[ + [permission_service.CreatePermissionRequest], + Union[gag_permission.Permission, Awaitable[gag_permission.Permission]], + ]: raise NotImplementedError() @property - def get_permission(self) -> Callable[ - [permission_service.GetPermissionRequest], - Union[ - permission.Permission, - Awaitable[permission.Permission] - ]]: + def get_permission( + self, + ) -> Callable[ + [permission_service.GetPermissionRequest], + Union[permission.Permission, Awaitable[permission.Permission]], + ]: raise NotImplementedError() @property - def list_permissions(self) -> Callable[ - [permission_service.ListPermissionsRequest], - Union[ - permission_service.ListPermissionsResponse, - Awaitable[permission_service.ListPermissionsResponse] - ]]: + def list_permissions( + self, + ) -> Callable[ + [permission_service.ListPermissionsRequest], + Union[ + permission_service.ListPermissionsResponse, + Awaitable[permission_service.ListPermissionsResponse], + ], + ]: raise NotImplementedError() @property - def update_permission(self) -> Callable[ - [permission_service.UpdatePermissionRequest], - Union[ - gag_permission.Permission, - Awaitable[gag_permission.Permission] - ]]: + def update_permission( + self, + ) -> Callable[ + [permission_service.UpdatePermissionRequest], + Union[gag_permission.Permission, Awaitable[gag_permission.Permission]], + ]: raise NotImplementedError() @property - def delete_permission(self) -> Callable[ - [permission_service.DeletePermissionRequest], - Union[ - empty_pb2.Empty, - Awaitable[empty_pb2.Empty] - ]]: + def delete_permission( + self, + ) -> Callable[ + [permission_service.DeletePermissionRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: raise NotImplementedError() @property - def transfer_ownership(self) -> Callable[ - [permission_service.TransferOwnershipRequest], - Union[ - permission_service.TransferOwnershipResponse, - Awaitable[permission_service.TransferOwnershipResponse] - ]]: + def transfer_ownership( + self, + ) -> Callable[ + [permission_service.TransferOwnershipRequest], + Union[ + permission_service.TransferOwnershipResponse, + Awaitable[permission_service.TransferOwnershipResponse], + ], + ]: raise NotImplementedError() @property @@ -216,6 +232,4 @@ def kind(self) -> str: raise NotImplementedError() -__all__ = ( - 'PermissionServiceTransport', -) +__all__ = ("PermissionServiceTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc.py similarity index 77% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc.py index b7d0a5160c35..006d82090b49 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc.py @@ -13,23 +13,22 @@ # See the 
License for the specific language governing permissions and # limitations under the License. # -import warnings from typing import Callable, Dict, Optional, Sequence, Tuple, Union +import warnings -from google.api_core import grpc_helpers -from google.api_core import gapic_v1 -import google.auth # type: ignore +from google.api_core import gapic_v1, grpc_helpers +import google.auth # type: ignore from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore import grpc # type: ignore -from google.ai.generativelanguage_v1beta3.types import permission from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.ai.generativelanguage_v1beta3.types import permission from google.ai.generativelanguage_v1beta3.types import permission_service -from google.longrunning import operations_pb2 # type: ignore -from google.protobuf import empty_pb2 # type: ignore -from .base import PermissionServiceTransport, DEFAULT_CLIENT_INFO + +from .base import DEFAULT_CLIENT_INFO, PermissionServiceTransport class PermissionServiceGrpcTransport(PermissionServiceTransport): @@ -45,23 +44,26 @@ class PermissionServiceGrpcTransport(PermissionServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[grpc.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: """Instantiate the transport. 
Args: @@ -180,13 +182,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. @@ -221,19 +225,20 @@ def create_channel(cls, default_scopes=cls.AUTH_SCOPES, scopes=scopes, default_host=cls.DEFAULT_HOST, - **kwargs + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property - def create_permission(self) -> Callable[ - [permission_service.CreatePermissionRequest], - gag_permission.Permission]: + def create_permission( + self, + ) -> Callable[ + [permission_service.CreatePermissionRequest], gag_permission.Permission + ]: r"""Return a callable for the create permission method over gRPC. Create a permission to a specific resource. @@ -248,18 +253,18 @@ def create_permission(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_permission' not in self._stubs: - self._stubs['create_permission'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.PermissionService/CreatePermission', + if "create_permission" not in self._stubs: + self._stubs["create_permission"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.PermissionService/CreatePermission", request_serializer=permission_service.CreatePermissionRequest.serialize, response_deserializer=gag_permission.Permission.deserialize, ) - return self._stubs['create_permission'] + return self._stubs["create_permission"] @property - def get_permission(self) -> Callable[ - [permission_service.GetPermissionRequest], - permission.Permission]: + def get_permission( + self, + ) -> Callable[[permission_service.GetPermissionRequest], permission.Permission]: r"""Return a callable for the get permission method over gRPC. Gets information about a specific Permission. @@ -274,18 +279,21 @@ def get_permission(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'get_permission' not in self._stubs: - self._stubs['get_permission'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.PermissionService/GetPermission', + if "get_permission" not in self._stubs: + self._stubs["get_permission"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.PermissionService/GetPermission", request_serializer=permission_service.GetPermissionRequest.serialize, response_deserializer=permission.Permission.deserialize, ) - return self._stubs['get_permission'] + return self._stubs["get_permission"] @property - def list_permissions(self) -> Callable[ - [permission_service.ListPermissionsRequest], - permission_service.ListPermissionsResponse]: + def list_permissions( + self, + ) -> Callable[ + [permission_service.ListPermissionsRequest], + permission_service.ListPermissionsResponse, + ]: r"""Return a callable for the list permissions method over gRPC. Lists permissions for the specific resource. @@ -300,18 +308,20 @@ def list_permissions(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_permissions' not in self._stubs: - self._stubs['list_permissions'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.PermissionService/ListPermissions', + if "list_permissions" not in self._stubs: + self._stubs["list_permissions"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.PermissionService/ListPermissions", request_serializer=permission_service.ListPermissionsRequest.serialize, response_deserializer=permission_service.ListPermissionsResponse.deserialize, ) - return self._stubs['list_permissions'] + return self._stubs["list_permissions"] @property - def update_permission(self) -> Callable[ - [permission_service.UpdatePermissionRequest], - gag_permission.Permission]: + def update_permission( + self, + ) -> Callable[ + [permission_service.UpdatePermissionRequest], gag_permission.Permission + ]: r"""Return a callable for the update permission method over gRPC. Updates the permission. @@ -326,18 +336,18 @@ def update_permission(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_permission' not in self._stubs: - self._stubs['update_permission'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.PermissionService/UpdatePermission', + if "update_permission" not in self._stubs: + self._stubs["update_permission"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.PermissionService/UpdatePermission", request_serializer=permission_service.UpdatePermissionRequest.serialize, response_deserializer=gag_permission.Permission.deserialize, ) - return self._stubs['update_permission'] + return self._stubs["update_permission"] @property - def delete_permission(self) -> Callable[ - [permission_service.DeletePermissionRequest], - empty_pb2.Empty]: + def delete_permission( + self, + ) -> Callable[[permission_service.DeletePermissionRequest], empty_pb2.Empty]: r"""Return a callable for the delete permission method over gRPC. Deletes the permission. @@ -352,18 +362,21 @@ def delete_permission(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'delete_permission' not in self._stubs: - self._stubs['delete_permission'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.PermissionService/DeletePermission', + if "delete_permission" not in self._stubs: + self._stubs["delete_permission"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.PermissionService/DeletePermission", request_serializer=permission_service.DeletePermissionRequest.serialize, response_deserializer=empty_pb2.Empty.FromString, ) - return self._stubs['delete_permission'] + return self._stubs["delete_permission"] @property - def transfer_ownership(self) -> Callable[ - [permission_service.TransferOwnershipRequest], - permission_service.TransferOwnershipResponse]: + def transfer_ownership( + self, + ) -> Callable[ + [permission_service.TransferOwnershipRequest], + permission_service.TransferOwnershipResponse, + ]: r"""Return a callable for the transfer ownership method over gRPC. Transfers ownership of the tuned model. @@ -381,13 +394,13 @@ def transfer_ownership(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'transfer_ownership' not in self._stubs: - self._stubs['transfer_ownership'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.PermissionService/TransferOwnership', + if "transfer_ownership" not in self._stubs: + self._stubs["transfer_ownership"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.PermissionService/TransferOwnership", request_serializer=permission_service.TransferOwnershipRequest.serialize, response_deserializer=permission_service.TransferOwnershipResponse.deserialize, ) - return self._stubs['transfer_ownership'] + return self._stubs["transfer_ownership"] def close(self): self.grpc_channel.close() @@ -397,6 +410,4 @@ def kind(self) -> str: return "grpc" -__all__ = ( - 'PermissionServiceGrpcTransport', -) +__all__ = ("PermissionServiceGrpcTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc_asyncio.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc_asyncio.py similarity index 77% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc_asyncio.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc_asyncio.py index 3f07d3cf3bd1..707f88eb3482 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc_asyncio.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/grpc_asyncio.py @@ -13,23 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union +import warnings -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers_async -from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import gapic_v1, grpc_helpers_async +from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore -from google.ai.generativelanguage_v1beta3.types import permission from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.ai.generativelanguage_v1beta3.types import permission from google.ai.generativelanguage_v1beta3.types import permission_service -from google.longrunning import operations_pb2 # type: ignore -from google.protobuf import empty_pb2 # type: ignore -from .base import PermissionServiceTransport, DEFAULT_CLIENT_INFO + +from .base import DEFAULT_CLIENT_INFO, PermissionServiceTransport from .grpc import PermissionServiceGrpcTransport @@ -51,13 +50,15 @@ class PermissionServiceGrpcAsyncIOTransport(PermissionServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. 
@@ -88,24 +89,26 @@ def create_channel(cls, default_scopes=cls.AUTH_SCOPES, scopes=scopes, default_host=cls.DEFAULT_HOST, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[aio.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: """Instantiate the transport. Args: @@ -234,9 +237,12 @@ def grpc_channel(self) -> aio.Channel: return self._grpc_channel @property - def create_permission(self) -> Callable[ - [permission_service.CreatePermissionRequest], - Awaitable[gag_permission.Permission]]: + def create_permission( + self, + ) -> Callable[ + [permission_service.CreatePermissionRequest], + Awaitable[gag_permission.Permission], + ]: r"""Return a callable for the create permission method over gRPC. Create a permission to a specific resource. @@ -251,18 +257,20 @@ def create_permission(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'create_permission' not in self._stubs: - self._stubs['create_permission'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.PermissionService/CreatePermission', + if "create_permission" not in self._stubs: + self._stubs["create_permission"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.PermissionService/CreatePermission", request_serializer=permission_service.CreatePermissionRequest.serialize, response_deserializer=gag_permission.Permission.deserialize, ) - return self._stubs['create_permission'] + return self._stubs["create_permission"] @property - def get_permission(self) -> Callable[ - [permission_service.GetPermissionRequest], - Awaitable[permission.Permission]]: + def get_permission( + self, + ) -> Callable[ + [permission_service.GetPermissionRequest], Awaitable[permission.Permission] + ]: r"""Return a callable for the get permission method over gRPC. Gets information about a specific Permission. @@ -277,18 +285,21 @@ def get_permission(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'get_permission' not in self._stubs: - self._stubs['get_permission'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.PermissionService/GetPermission', + if "get_permission" not in self._stubs: + self._stubs["get_permission"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.PermissionService/GetPermission", request_serializer=permission_service.GetPermissionRequest.serialize, response_deserializer=permission.Permission.deserialize, ) - return self._stubs['get_permission'] + return self._stubs["get_permission"] @property - def list_permissions(self) -> Callable[ - [permission_service.ListPermissionsRequest], - Awaitable[permission_service.ListPermissionsResponse]]: + def list_permissions( + self, + ) -> Callable[ + [permission_service.ListPermissionsRequest], + Awaitable[permission_service.ListPermissionsResponse], + ]: r"""Return a callable for the list permissions method over gRPC. Lists permissions for the specific resource. @@ -303,18 +314,21 @@ def list_permissions(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'list_permissions' not in self._stubs: - self._stubs['list_permissions'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.PermissionService/ListPermissions', + if "list_permissions" not in self._stubs: + self._stubs["list_permissions"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.PermissionService/ListPermissions", request_serializer=permission_service.ListPermissionsRequest.serialize, response_deserializer=permission_service.ListPermissionsResponse.deserialize, ) - return self._stubs['list_permissions'] + return self._stubs["list_permissions"] @property - def update_permission(self) -> Callable[ - [permission_service.UpdatePermissionRequest], - Awaitable[gag_permission.Permission]]: + def update_permission( + self, + ) -> Callable[ + [permission_service.UpdatePermissionRequest], + Awaitable[gag_permission.Permission], + ]: r"""Return a callable for the update permission method over gRPC. Updates the permission. @@ -329,18 +343,20 @@ def update_permission(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'update_permission' not in self._stubs: - self._stubs['update_permission'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.PermissionService/UpdatePermission', + if "update_permission" not in self._stubs: + self._stubs["update_permission"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.PermissionService/UpdatePermission", request_serializer=permission_service.UpdatePermissionRequest.serialize, response_deserializer=gag_permission.Permission.deserialize, ) - return self._stubs['update_permission'] + return self._stubs["update_permission"] @property - def delete_permission(self) -> Callable[ - [permission_service.DeletePermissionRequest], - Awaitable[empty_pb2.Empty]]: + def delete_permission( + self, + ) -> Callable[ + [permission_service.DeletePermissionRequest], Awaitable[empty_pb2.Empty] + ]: r"""Return a callable for the delete permission method over gRPC. Deletes the permission. @@ -355,18 +371,21 @@ def delete_permission(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'delete_permission' not in self._stubs: - self._stubs['delete_permission'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.PermissionService/DeletePermission', + if "delete_permission" not in self._stubs: + self._stubs["delete_permission"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.PermissionService/DeletePermission", request_serializer=permission_service.DeletePermissionRequest.serialize, response_deserializer=empty_pb2.Empty.FromString, ) - return self._stubs['delete_permission'] + return self._stubs["delete_permission"] @property - def transfer_ownership(self) -> Callable[ - [permission_service.TransferOwnershipRequest], - Awaitable[permission_service.TransferOwnershipResponse]]: + def transfer_ownership( + self, + ) -> Callable[ + [permission_service.TransferOwnershipRequest], + Awaitable[permission_service.TransferOwnershipResponse], + ]: r"""Return a callable for the transfer ownership method over gRPC. Transfers ownership of the tuned model. @@ -384,18 +403,16 @@ def transfer_ownership(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'transfer_ownership' not in self._stubs: - self._stubs['transfer_ownership'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.PermissionService/TransferOwnership', + if "transfer_ownership" not in self._stubs: + self._stubs["transfer_ownership"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.PermissionService/TransferOwnership", request_serializer=permission_service.TransferOwnershipRequest.serialize, response_deserializer=permission_service.TransferOwnershipResponse.deserialize, ) - return self._stubs['transfer_ownership'] + return self._stubs["transfer_ownership"] def close(self): return self.grpc_channel.close() -__all__ = ( - 'PermissionServiceGrpcAsyncIOTransport', -) +__all__ = ("PermissionServiceGrpcAsyncIOTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/rest.py similarity index 70% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/rest.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/rest.py index 913244341d45..12af3b148a45 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/permission_service/transports/rest.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/transports/rest.py @@ -14,24 +14,21 @@ # limitations under the License. 
# -from google.auth.transport.requests import AuthorizedSession # type: ignore +import dataclasses import json # type: ignore -import grpc # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth import credentials as ga_credentials # type: ignore +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import gapic_v1, path_template, rest_helpers, rest_streaming from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries -from google.api_core import rest_helpers -from google.api_core import rest_streaming -from google.api_core import path_template -from google.api_core import gapic_v1 - +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.transport.requests import AuthorizedSession # type: ignore from google.protobuf import json_format +import grpc # type: ignore from requests import __version__ as requests_version -import dataclasses -import re -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union -import warnings try: OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] @@ -39,14 +36,15 @@ OptionalRetry = Union[retries.Retry, object] # type: ignore -from google.ai.generativelanguage_v1beta3.types import permission -from google.ai.generativelanguage_v1beta3.types import permission as gag_permission -from google.ai.generativelanguage_v1beta3.types import permission_service -from google.protobuf import empty_pb2 # type: ignore from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore -from .base import PermissionServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.ai.generativelanguage_v1beta3.types import permission +from google.ai.generativelanguage_v1beta3.types import permission_service +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .base import PermissionServiceTransport DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, @@ -119,7 +117,12 @@ def post_update_permission(self, response): """ - def pre_create_permission(self, request: permission_service.CreatePermissionRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[permission_service.CreatePermissionRequest, Sequence[Tuple[str, str]]]: + + def pre_create_permission( + self, + request: permission_service.CreatePermissionRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[permission_service.CreatePermissionRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for create_permission Override in a subclass to manipulate the request or metadata @@ -127,7 +130,9 @@ def pre_create_permission(self, request: permission_service.CreatePermissionRequ """ return request, metadata - def post_create_permission(self, response: gag_permission.Permission) -> gag_permission.Permission: + def post_create_permission( + self, response: gag_permission.Permission + ) -> gag_permission.Permission: """Post-rpc interceptor for create_permission Override in a subclass to manipulate the response @@ -135,7 +140,12 @@ def post_create_permission(self, response: gag_permission.Permission) -> gag_per it is returned to user code. 
""" return response - def pre_delete_permission(self, request: permission_service.DeletePermissionRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[permission_service.DeletePermissionRequest, Sequence[Tuple[str, str]]]: + + def pre_delete_permission( + self, + request: permission_service.DeletePermissionRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[permission_service.DeletePermissionRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for delete_permission Override in a subclass to manipulate the request or metadata @@ -143,7 +153,11 @@ def pre_delete_permission(self, request: permission_service.DeletePermissionRequ """ return request, metadata - def pre_get_permission(self, request: permission_service.GetPermissionRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[permission_service.GetPermissionRequest, Sequence[Tuple[str, str]]]: + def pre_get_permission( + self, + request: permission_service.GetPermissionRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[permission_service.GetPermissionRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for get_permission Override in a subclass to manipulate the request or metadata @@ -151,7 +165,9 @@ def pre_get_permission(self, request: permission_service.GetPermissionRequest, m """ return request, metadata - def post_get_permission(self, response: permission.Permission) -> permission.Permission: + def post_get_permission( + self, response: permission.Permission + ) -> permission.Permission: """Post-rpc interceptor for get_permission Override in a subclass to manipulate the response @@ -159,7 +175,12 @@ def post_get_permission(self, response: permission.Permission) -> permission.Per it is returned to user code. """ return response - def pre_list_permissions(self, request: permission_service.ListPermissionsRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[permission_service.ListPermissionsRequest, Sequence[Tuple[str, str]]]: + + def pre_list_permissions( + self, + request: permission_service.ListPermissionsRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[permission_service.ListPermissionsRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for list_permissions Override in a subclass to manipulate the request or metadata @@ -167,7 +188,9 @@ def pre_list_permissions(self, request: permission_service.ListPermissionsReques """ return request, metadata - def post_list_permissions(self, response: permission_service.ListPermissionsResponse) -> permission_service.ListPermissionsResponse: + def post_list_permissions( + self, response: permission_service.ListPermissionsResponse + ) -> permission_service.ListPermissionsResponse: """Post-rpc interceptor for list_permissions Override in a subclass to manipulate the response @@ -175,7 +198,12 @@ def post_list_permissions(self, response: permission_service.ListPermissionsResp it is returned to user code. 
""" return response - def pre_transfer_ownership(self, request: permission_service.TransferOwnershipRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[permission_service.TransferOwnershipRequest, Sequence[Tuple[str, str]]]: + + def pre_transfer_ownership( + self, + request: permission_service.TransferOwnershipRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[permission_service.TransferOwnershipRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for transfer_ownership Override in a subclass to manipulate the request or metadata @@ -183,7 +211,9 @@ def pre_transfer_ownership(self, request: permission_service.TransferOwnershipRe """ return request, metadata - def post_transfer_ownership(self, response: permission_service.TransferOwnershipResponse) -> permission_service.TransferOwnershipResponse: + def post_transfer_ownership( + self, response: permission_service.TransferOwnershipResponse + ) -> permission_service.TransferOwnershipResponse: """Post-rpc interceptor for transfer_ownership Override in a subclass to manipulate the response @@ -191,7 +221,12 @@ def post_transfer_ownership(self, response: permission_service.TransferOwnership it is returned to user code. """ return response - def pre_update_permission(self, request: permission_service.UpdatePermissionRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[permission_service.UpdatePermissionRequest, Sequence[Tuple[str, str]]]: + + def pre_update_permission( + self, + request: permission_service.UpdatePermissionRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[permission_service.UpdatePermissionRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for update_permission Override in a subclass to manipulate the request or metadata @@ -199,7 +234,9 @@ def pre_update_permission(self, request: permission_service.UpdatePermissionRequ """ return request, metadata - def post_update_permission(self, response: gag_permission.Permission) -> gag_permission.Permission: + def post_update_permission( + self, response: gag_permission.Permission + ) -> gag_permission.Permission: """Post-rpc interceptor for update_permission Override in a subclass to manipulate the response @@ -230,20 +267,21 @@ class PermissionServiceRestTransport(PermissionServiceTransport): """ - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - client_cert_source_for_mtls: Optional[Callable[[ - ], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - url_scheme: str = 'https', - interceptor: Optional[PermissionServiceRestInterceptor] = None, - api_audience: Optional[str] = None, - ) -> None: + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[PermissionServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: """Instantiate the transport. 
Args: @@ -282,7 +320,9 @@ def __init__(self, *, # credentials object maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host) if maybe_url_match is None: - raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER url_match_items = maybe_url_match.groupdict() @@ -293,10 +333,11 @@ def __init__(self, *, credentials=credentials, client_info=client_info, always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience + api_audience=api_audience, ) self._session = AuthorizedSession( - self._credentials, default_host=self.DEFAULT_HOST) + self._credentials, default_host=self.DEFAULT_HOST + ) if client_cert_source_for_mtls: self._session.configure_mtls_channel(client_cert_source_for_mtls) self._interceptor = interceptor or PermissionServiceRestInterceptor() @@ -306,19 +347,24 @@ class _CreatePermission(PermissionServiceRestStub): def __hash__(self): return hash("CreatePermission") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: permission_service.CreatePermissionRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> gag_permission.Permission: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: permission_service.CreatePermissionRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_permission.Permission: r"""Call the create permission method over HTTP.
Args: @@ -358,46 +404,51 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'post', - 'uri': '/v1beta3/{parent=tunedModels/*}/permissions', - 'body': 'permission', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1beta3/{parent=tunedModels/*}/permissions", + "body": "permission", + }, ] - request, metadata = self._interceptor.pre_create_permission(request, metadata) + request, metadata = self._interceptor.pre_create_permission( + request, metadata + ) pb_request = permission_service.CreatePermissionRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) # Jsonify the request body body = json_format.MessageToJson( - transcoded_request['body'], + transcoded_request["body"], including_default_value_fields=False, - use_integers_for_enums=True + use_integers_for_enums=True, ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), data=body, - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -416,19 +467,24 @@ class _DeletePermission(PermissionServiceRestStub): def __hash__(self): return hash("DeletePermission") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: permission_service.DeletePermissionRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: permission_service.DeletePermissionRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ): r"""Call the delete permission method over HTTP. Args: @@ -441,37 +497,42 @@ def __call__(self, sent along with the request as metadata. 
""" - http_options: List[Dict[str, str]] = [{ - 'method': 'delete', - 'uri': '/v1beta3/{name=tunedModels/*/permissions/*}', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1beta3/{name=tunedModels/*/permissions/*}", + }, ] - request, metadata = self._interceptor.pre_delete_permission(request, metadata) + request, metadata = self._interceptor.pre_delete_permission( + request, metadata + ) pb_request = permission_service.DeletePermissionRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -482,19 +543,24 @@ class _GetPermission(PermissionServiceRestStub): def __hash__(self): return hash("GetPermission") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: permission_service.GetPermissionRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> permission.Permission: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: permission_service.GetPermissionRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> permission.Permission: r"""Call the get permission method over HTTP. 
Args: @@ -535,37 +601,40 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'get', - 'uri': '/v1beta3/{name=tunedModels/*/permissions/*}', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1beta3/{name=tunedModels/*/permissions/*}", + }, ] request, metadata = self._interceptor.pre_get_permission(request, metadata) pb_request = permission_service.GetPermissionRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -584,19 +653,24 @@ class _ListPermissions(PermissionServiceRestStub): def __hash__(self): return hash("ListPermissions") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: permission_service.ListPermissionsRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> permission_service.ListPermissionsResponse: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: permission_service.ListPermissionsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> permission_service.ListPermissionsResponse: r"""Call the list permissions method over HTTP. 
Args: @@ -615,37 +689,42 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'get', - 'uri': '/v1beta3/{parent=tunedModels/*}/permissions', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1beta3/{parent=tunedModels/*}/permissions", + }, ] - request, metadata = self._interceptor.pre_list_permissions(request, metadata) + request, metadata = self._interceptor.pre_list_permissions( + request, metadata + ) pb_request = permission_service.ListPermissionsRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -664,19 +743,24 @@ class _TransferOwnership(PermissionServiceRestStub): def __hash__(self): return hash("TransferOwnership") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: permission_service.TransferOwnershipRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> permission_service.TransferOwnershipResponse: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: permission_service.TransferOwnershipRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> permission_service.TransferOwnershipResponse: r"""Call the transfer ownership method over HTTP. Args: @@ -694,46 +778,51 @@ def __call__(self, Response from ``TransferOwnership``. 
""" - http_options: List[Dict[str, str]] = [{ - 'method': 'post', - 'uri': '/v1beta3/{name=tunedModels/*}:transferOwnership', - 'body': '*', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1beta3/{name=tunedModels/*}:transferOwnership", + "body": "*", + }, ] - request, metadata = self._interceptor.pre_transfer_ownership(request, metadata) + request, metadata = self._interceptor.pre_transfer_ownership( + request, metadata + ) pb_request = permission_service.TransferOwnershipRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) # Jsonify the request body body = json_format.MessageToJson( - transcoded_request['body'], + transcoded_request["body"], including_default_value_fields=False, - use_integers_for_enums=True + use_integers_for_enums=True, ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), data=body, - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -752,19 +841,26 @@ class _UpdatePermission(PermissionServiceRestStub): def __hash__(self): return hash("UpdatePermission") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - "updateMask" : {}, } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + "updateMask": {}, + } @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: permission_service.UpdatePermissionRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> gag_permission.Permission: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: permission_service.UpdatePermissionRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gag_permission.Permission: r"""Call the update permission method over HTTP. 
Args: @@ -804,46 +900,51 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'patch', - 'uri': '/v1beta3/{permission.name=tunedModels/*/permissions/*}', - 'body': 'permission', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1beta3/{permission.name=tunedModels/*/permissions/*}", + "body": "permission", + }, ] - request, metadata = self._interceptor.pre_update_permission(request, metadata) + request, metadata = self._interceptor.pre_update_permission( + request, metadata + ) pb_request = permission_service.UpdatePermissionRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) # Jsonify the request body body = json_format.MessageToJson( - transcoded_request['body'], + transcoded_request["body"], including_default_value_fields=False, - use_integers_for_enums=True + use_integers_for_enums=True, ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), data=body, - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -859,52 +960,62 @@ def __call__(self, return resp @property - def create_permission(self) -> Callable[ - [permission_service.CreatePermissionRequest], - gag_permission.Permission]: + def create_permission( + self, + ) -> Callable[ + [permission_service.CreatePermissionRequest], gag_permission.Permission + ]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._CreatePermission(self._session, self._host, self._interceptor) # type: ignore + return self._CreatePermission(self._session, self._host, self._interceptor) # type: ignore @property - def delete_permission(self) -> Callable[ - [permission_service.DeletePermissionRequest], - empty_pb2.Empty]: + def delete_permission( + self, + ) -> Callable[[permission_service.DeletePermissionRequest], empty_pb2.Empty]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._DeletePermission(self._session, self._host, self._interceptor) # type: ignore + return self._DeletePermission(self._session, self._host, self._interceptor) # type: ignore @property - def get_permission(self) -> Callable[ - [permission_service.GetPermissionRequest], - permission.Permission]: + def get_permission( + self, + ) -> Callable[[permission_service.GetPermissionRequest], permission.Permission]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
# In C++ this would require a dynamic_cast - return self._GetPermission(self._session, self._host, self._interceptor) # type: ignore + return self._GetPermission(self._session, self._host, self._interceptor) # type: ignore @property - def list_permissions(self) -> Callable[ - [permission_service.ListPermissionsRequest], - permission_service.ListPermissionsResponse]: + def list_permissions( + self, + ) -> Callable[ + [permission_service.ListPermissionsRequest], + permission_service.ListPermissionsResponse, + ]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._ListPermissions(self._session, self._host, self._interceptor) # type: ignore + return self._ListPermissions(self._session, self._host, self._interceptor) # type: ignore @property - def transfer_ownership(self) -> Callable[ - [permission_service.TransferOwnershipRequest], - permission_service.TransferOwnershipResponse]: + def transfer_ownership( + self, + ) -> Callable[ + [permission_service.TransferOwnershipRequest], + permission_service.TransferOwnershipResponse, + ]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._TransferOwnership(self._session, self._host, self._interceptor) # type: ignore + return self._TransferOwnership(self._session, self._host, self._interceptor) # type: ignore @property - def update_permission(self) -> Callable[ - [permission_service.UpdatePermissionRequest], - gag_permission.Permission]: + def update_permission( + self, + ) -> Callable[ + [permission_service.UpdatePermissionRequest], gag_permission.Permission + ]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._UpdatePermission(self._session, self._host, self._interceptor) # type: ignore + return self._UpdatePermission(self._session, self._host, self._interceptor) # type: ignore @property def kind(self) -> str: @@ -914,6 +1025,4 @@ def close(self): self._session.close() -__all__=( - 'PermissionServiceRestTransport', -) +__all__ = ("PermissionServiceRestTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/__init__.py similarity index 92% rename from owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/__init__.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/__init__.py index f167a9c3175d..f705e582e7a1 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/__init__.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/__init__.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 #
-from .client import TextServiceClient
 from .async_client import TextServiceAsyncClient
+from .client import TextServiceClient

 __all__ = (
-    'TextServiceClient',
-    'TextServiceAsyncClient',
+    "TextServiceClient",
+    "TextServiceAsyncClient",
 )
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/async_client.py
similarity index 84%
rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/async_client.py
rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/async_client.py
index f1be99ee3ab2..bd669bc09a57 100644
--- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/async_client.py
+++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/async_client.py
@@ -16,28 +16,39 @@
 from collections import OrderedDict
 import functools
 import re
-from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union
-
-from google.ai.generativelanguage_v1beta3 import gapic_version as package_version
+from typing import (
+    Dict,
+    Mapping,
+    MutableMapping,
+    MutableSequence,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)

-from google.api_core.client_options import ClientOptions
 from google.api_core import exceptions as core_exceptions
 from google.api_core import gapic_v1
 from google.api_core import retry as retries
-from google.auth import credentials as ga_credentials # type: ignore
-from google.oauth2 import service_account # type: ignore
+from google.api_core.client_options import ClientOptions
+from google.auth import credentials as ga_credentials  # type: ignore
+from google.oauth2 import service_account  # type: ignore
+
+from google.ai.generativelanguage_v1beta3 import gapic_version as package_version

 try:
     OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
 except AttributeError:  # pragma: NO COVER
     OptionalRetry = Union[retries.Retry, object]  # type: ignore

-from google.ai.generativelanguage_v1beta3.types import safety
-from google.ai.generativelanguage_v1beta3.types import text_service
-from google.longrunning import operations_pb2 # type: ignore
-from .transports.base import TextServiceTransport, DEFAULT_CLIENT_INFO
-from .transports.grpc_asyncio import TextServiceGrpcAsyncIOTransport
+from google.longrunning import operations_pb2  # type: ignore
+
+from google.ai.generativelanguage_v1beta3.types import safety, text_service
+
 from .client import TextServiceClient
+from .transports.base import DEFAULT_CLIENT_INFO, TextServiceTransport
+from .transports.grpc_asyncio import TextServiceGrpcAsyncIOTransport


 class TextServiceAsyncClient:
@@ -54,16 +65,26 @@ class TextServiceAsyncClient:
     model_path = staticmethod(TextServiceClient.model_path)
     parse_model_path = staticmethod(TextServiceClient.parse_model_path)
-    common_billing_account_path = staticmethod(TextServiceClient.common_billing_account_path)
-    parse_common_billing_account_path = staticmethod(TextServiceClient.parse_common_billing_account_path)
+    common_billing_account_path = staticmethod(
+        TextServiceClient.common_billing_account_path
+    )
+    parse_common_billing_account_path = staticmethod(
+        TextServiceClient.parse_common_billing_account_path
+    )
     common_folder_path = staticmethod(TextServiceClient.common_folder_path)
     parse_common_folder_path = staticmethod(TextServiceClient.parse_common_folder_path)
     common_organization_path = staticmethod(TextServiceClient.common_organization_path)
-    parse_common_organization_path = staticmethod(TextServiceClient.parse_common_organization_path)
+    parse_common_organization_path = staticmethod(
+        TextServiceClient.parse_common_organization_path
+    )
     common_project_path = staticmethod(TextServiceClient.common_project_path)
-    parse_common_project_path = staticmethod(TextServiceClient.parse_common_project_path)
+    parse_common_project_path = staticmethod(
+        TextServiceClient.parse_common_project_path
+    )
     common_location_path = staticmethod(TextServiceClient.common_location_path)
-    parse_common_location_path = staticmethod(TextServiceClient.parse_common_location_path)
+    parse_common_location_path = staticmethod(
+        TextServiceClient.parse_common_location_path
+    )

     @classmethod
     def from_service_account_info(cls, info: dict, *args, **kwargs):
@@ -99,7 +120,9 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
     from_service_account_json = from_service_account_file

     @classmethod
-    def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[ClientOptions] = None):
+    def get_mtls_endpoint_and_cert_source(
+        cls, client_options: Optional[ClientOptions] = None
+    ):
         """Return the API endpoint and client cert source for mutual TLS.

         The client cert source is determined in the following order:
@@ -141,14 +164,18 @@ def transport(self) -> TextServiceTransport:
         """
         return self._client.transport

-    get_transport_class = functools.partial(type(TextServiceClient).get_transport_class, type(TextServiceClient))
-
-    def __init__(self, *,
-            credentials: Optional[ga_credentials.Credentials] = None,
-            transport: Union[str, TextServiceTransport] = "grpc_asyncio",
-            client_options: Optional[ClientOptions] = None,
-            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
-            ) -> None:
+    get_transport_class = functools.partial(
+        type(TextServiceClient).get_transport_class, type(TextServiceClient)
+    )
+
+    def __init__(
+        self,
+        *,
+        credentials: Optional[ga_credentials.Credentials] = None,
+        transport: Union[str, TextServiceTransport] = "grpc_asyncio",
+        client_options: Optional[ClientOptions] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
         """Instantiates the text service client.
Args: @@ -186,23 +213,23 @@ def __init__(self, *, transport=transport, client_options=client_options, client_info=client_info, - ) - async def generate_text(self, - request: Optional[Union[text_service.GenerateTextRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[text_service.TextPrompt] = None, - temperature: Optional[float] = None, - candidate_count: Optional[int] = None, - max_output_tokens: Optional[int] = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> text_service.GenerateTextResponse: + async def generate_text( + self, + request: Optional[Union[text_service.GenerateTextRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[text_service.TextPrompt] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + max_output_tokens: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.GenerateTextResponse: r"""Generates a response from the model given an input message. @@ -343,10 +370,22 @@ async def sample_generate_text(): # Create or coerce a protobuf request object. # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([model, prompt, temperature, candidate_count, max_output_tokens, top_p, top_k]) + has_flattened_params = any( + [ + model, + prompt, + temperature, + candidate_count, + max_output_tokens, + top_p, + top_k, + ] + ) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = text_service.GenerateTextRequest(request) @@ -378,9 +417,7 @@ async def sample_generate_text(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), ) # Send the request. @@ -394,15 +431,16 @@ async def sample_generate_text(): # Done; return the response. return response - async def embed_text(self, - request: Optional[Union[text_service.EmbedTextRequest, dict]] = None, - *, - model: Optional[str] = None, - text: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> text_service.EmbedTextResponse: + async def embed_text( + self, + request: Optional[Union[text_service.EmbedTextRequest, dict]] = None, + *, + model: Optional[str] = None, + text: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.EmbedTextResponse: r"""Generates an embedding from the model given an input message. @@ -467,8 +505,10 @@ async def sample_embed_text(): # gotten any keyword arguments that map to the request. 
has_flattened_params = any([model, text]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = text_service.EmbedTextRequest(request) @@ -490,9 +530,7 @@ async def sample_embed_text(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), ) # Send the request. @@ -506,15 +544,16 @@ async def sample_embed_text(): # Done; return the response. return response - async def batch_embed_text(self, - request: Optional[Union[text_service.BatchEmbedTextRequest, dict]] = None, - *, - model: Optional[str] = None, - texts: Optional[MutableSequence[str]] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> text_service.BatchEmbedTextResponse: + async def batch_embed_text( + self, + request: Optional[Union[text_service.BatchEmbedTextRequest, dict]] = None, + *, + model: Optional[str] = None, + texts: Optional[MutableSequence[str]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.BatchEmbedTextResponse: r"""Generates multiple embeddings from the model given input text in a synchronous call. @@ -582,8 +621,10 @@ async def sample_batch_embed_text(): # gotten any keyword arguments that map to the request. has_flattened_params = any([model, texts]) if request is not None and has_flattened_params: - raise ValueError("If the `request` argument is set, then none of " - "the individual field arguments should be set.") + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) request = text_service.BatchEmbedTextRequest(request) @@ -605,9 +646,7 @@ async def sample_batch_embed_text(): # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), ) # Send the request. @@ -621,15 +660,16 @@ async def sample_batch_embed_text(): # Done; return the response. return response - async def count_text_tokens(self, - request: Optional[Union[text_service.CountTextTokensRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[text_service.TextPrompt] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> text_service.CountTextTokensResponse: + async def count_text_tokens( + self, + request: Optional[Union[text_service.CountTextTokensRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[text_service.TextPrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.CountTextTokensResponse: r"""Runs a model's tokenizer on a text and returns the token count. 
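The guard shown above appears in every flattened-call method: a caller may supply a fully-formed request object or the individual keyword fields, but never both. A minimal sketch of the two accepted call shapes, assuming the request types are re-exported from the versioned package root and using an illustrative model name:

    from google.ai.generativelanguage_v1beta3 import EmbedTextRequest

    # Shape 1: pass a fully-formed request object.
    request = EmbedTextRequest(model="models/embedding-gecko-001", text="hello world")
    # await client.embed_text(request=request)

    # Shape 2: pass only the flattened fields.
    # await client.embed_text(model="models/embedding-gecko-001", text="hello world")

    # Mixing the two shapes trips the has_flattened_params check and raises
    # ValueError before any RPC is issued.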
@@ -707,8 +747,10 @@ async def sample_count_text_tokens():
         # gotten any keyword arguments that map to the request.
         has_flattened_params = any([model, prompt])
         if request is not None and has_flattened_params:
-            raise ValueError("If the `request` argument is set, then none of "
-                             "the individual field arguments should be set.")
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )

         request = text_service.CountTextTokensRequest(request)

@@ -730,9 +772,7 @@ async def sample_count_text_tokens():
         # Certain fields should be provided within the metadata header;
         # add these here.
         metadata = tuple(metadata) + (
-            gapic_v1.routing_header.to_grpc_metadata((
-                ("model", request.model),
-            )),
+            gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)),
         )

         # Send the request.
@@ -752,9 +792,10 @@ async def __aenter__(self) -> "TextServiceAsyncClient":
     async def __aexit__(self, exc_type, exc, tb):
         await self.transport.close()

-DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__)
-
-__all__ = (
-    "TextServiceAsyncClient",
+DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+    gapic_version=package_version.__version__
 )
+
+
+__all__ = ("TextServiceAsyncClient",)
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/client.py
similarity index 82%
rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/client.py
rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/client.py
index ed05565059cc..09ef955f3bf1 100644
--- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/client.py
+++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/client.py
@@ -16,29 +16,41 @@
 from collections import OrderedDict
 import os
 import re
-from typing import Dict, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, Union, cast
-
-from google.ai.generativelanguage_v1beta3 import gapic_version as package_version
+from typing import (
+    Dict,
+    Mapping,
+    MutableMapping,
+    MutableSequence,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+    cast,
+)

 from google.api_core import client_options as client_options_lib
 from google.api_core import exceptions as core_exceptions
 from google.api_core import gapic_v1
 from google.api_core import retry as retries
-from google.auth import credentials as ga_credentials # type: ignore
-from google.auth.transport import mtls # type: ignore
-from google.auth.transport.grpc import SslCredentials # type: ignore
-from google.auth.exceptions import MutualTLSChannelError # type: ignore
-from google.oauth2 import service_account # type: ignore
+from google.auth import credentials as ga_credentials  # type: ignore
+from google.auth.exceptions import MutualTLSChannelError  # type: ignore
+from google.auth.transport import mtls  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+from google.oauth2 import service_account  # type: ignore
+
+from google.ai.generativelanguage_v1beta3 import gapic_version as package_version

 try:
     OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault]
 except AttributeError:  # pragma: NO COVER
     OptionalRetry = Union[retries.Retry, object]  # type: ignore

-from google.ai.generativelanguage_v1beta3.types import safety
-from google.ai.generativelanguage_v1beta3.types import text_service
-from google.longrunning import operations_pb2 # type: ignore
-from .transports.base import TextServiceTransport, DEFAULT_CLIENT_INFO
+from google.longrunning import operations_pb2  # type: ignore
+
+from google.ai.generativelanguage_v1beta3.types import safety, text_service
+
+from .transports.base import DEFAULT_CLIENT_INFO, TextServiceTransport
 from .transports.grpc import TextServiceGrpcTransport
 from .transports.grpc_asyncio import TextServiceGrpcAsyncIOTransport
 from .transports.rest import TextServiceRestTransport
@@ -51,14 +63,16 @@ class TextServiceClientMeta(type):
     support objects (e.g. transport) without polluting the client instance
     objects.
     """
+
     _transport_registry = OrderedDict()  # type: Dict[str, Type[TextServiceTransport]]
     _transport_registry["grpc"] = TextServiceGrpcTransport
     _transport_registry["grpc_asyncio"] = TextServiceGrpcAsyncIOTransport
     _transport_registry["rest"] = TextServiceRestTransport

-    def get_transport_class(cls,
-            label: Optional[str] = None,
-        ) -> Type[TextServiceTransport]:
+    def get_transport_class(
+        cls,
+        label: Optional[str] = None,
+    ) -> Type[TextServiceTransport]:
         """Returns an appropriate transport class.

         Args:
@@ -150,8 +164,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs):
         Returns:
             TextServiceClient: The constructed client.
         """
-        credentials = service_account.Credentials.from_service_account_file(
-            filename)
+        credentials = service_account.Credentials.from_service_account_file(filename)
         kwargs["credentials"] = credentials
         return cls(*args, **kwargs)

@@ -168,73 +181,101 @@ def transport(self) -> TextServiceTransport:
         return self._transport

     @staticmethod
-    def model_path(model: str,) -> str:
+    def model_path(
+        model: str,
+    ) -> str:
         """Returns a fully-qualified model string."""
-        return "models/{model}".format(model=model, )
+        return "models/{model}".format(
+            model=model,
+        )

     @staticmethod
-    def parse_model_path(path: str) -> Dict[str,str]:
+    def parse_model_path(path: str) -> Dict[str, str]:
         """Parses a model path into its component segments."""
         m = re.match(r"^models/(?P<model>.+?)$", path)
         return m.groupdict() if m else {}

     @staticmethod
-    def common_billing_account_path(billing_account: str, ) -> str:
+    def common_billing_account_path(
+        billing_account: str,
+    ) -> str:
         """Returns a fully-qualified billing_account string."""
-        return "billingAccounts/{billing_account}".format(billing_account=billing_account, )
+        return "billingAccounts/{billing_account}".format(
+            billing_account=billing_account,
+        )

     @staticmethod
-    def parse_common_billing_account_path(path: str) -> Dict[str,str]:
+    def parse_common_billing_account_path(path: str) -> Dict[str, str]:
         """Parse a billing_account path into its component segments."""
         m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
         return m.groupdict() if m else {}

     @staticmethod
-    def common_folder_path(folder: str, ) -> str:
+    def common_folder_path(
+        folder: str,
+    ) -> str:
         """Returns a fully-qualified folder string."""
-        return "folders/{folder}".format(folder=folder, )
+        return "folders/{folder}".format(
+            folder=folder,
+        )

     @staticmethod
-    def parse_common_folder_path(path: str) -> Dict[str,str]:
+    def parse_common_folder_path(path: str) -> Dict[str, str]:
         """Parse a folder path into its component segments."""
         m = re.match(r"^folders/(?P<folder>.+?)$", path)
         return m.groupdict() if m else {}

     @staticmethod
-    def common_organization_path(organization: str, ) -> str:
+    def common_organization_path(
+        organization: str,
+    ) -> str:
         """Returns a fully-qualified organization string."""
-        return "organizations/{organization}".format(organization=organization, )
+        return "organizations/{organization}".format(
+            organization=organization,
+        )

     @staticmethod
-    def parse_common_organization_path(path: str) -> Dict[str,str]:
+    def parse_common_organization_path(path: str) -> Dict[str, str]:
         """Parse a organization path into its component segments."""
         m = re.match(r"^organizations/(?P<organization>.+?)$", path)
         return m.groupdict() if m else {}

     @staticmethod
-    def common_project_path(project: str, ) -> str:
+    def common_project_path(
+        project: str,
+    ) -> str:
         """Returns a fully-qualified project string."""
-        return "projects/{project}".format(project=project, )
+        return "projects/{project}".format(
+            project=project,
+        )

     @staticmethod
-    def parse_common_project_path(path: str) -> Dict[str,str]:
+    def parse_common_project_path(path: str) -> Dict[str, str]:
         """Parse a project path into its component segments."""
         m = re.match(r"^projects/(?P<project>.+?)$", path)
         return m.groupdict() if m else {}

     @staticmethod
-    def common_location_path(project: str, location: str, ) -> str:
+    def common_location_path(
+        project: str,
+        location: str,
+    ) -> str:
         """Returns a fully-qualified location string."""
-        return "projects/{project}/locations/{location}".format(project=project, location=location, )
+        return "projects/{project}/locations/{location}".format(
+            project=project,
+            location=location,
+        )

     @staticmethod
-    def parse_common_location_path(path: str) -> Dict[str,str]:
+    def parse_common_location_path(path: str) -> Dict[str, str]:
         """Parse a location path into its component segments."""
         m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
         return m.groupdict() if m else {}

     @classmethod
-    def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_options_lib.ClientOptions] = None):
+    def get_mtls_endpoint_and_cert_source(
+        cls, client_options: Optional[client_options_lib.ClientOptions] = None
+    ):
         """Return the API endpoint and client cert source for mutual TLS.

         The client cert source is determined in the following order:
@@ -270,9 +311,13 @@ def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_optio
         use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")
         use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
         if use_client_cert not in ("true", "false"):
-            raise ValueError("Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`")
+            raise ValueError(
+                "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`"
+            )
         if use_mtls_endpoint not in ("auto", "never", "always"):
-            raise MutualTLSChannelError("Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`")
+            raise MutualTLSChannelError(
+                "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`"
+            )

         # Figure out the client cert source to use.
         client_cert_source = None
@@ -285,19 +330,23 @@ def get_mtls_endpoint_and_cert_source(cls, client_options: Optional[client_optio
         # Figure out which api endpoint to use.
if client_options.api_endpoint is not None: api_endpoint = client_options.api_endpoint - elif use_mtls_endpoint == "always" or (use_mtls_endpoint == "auto" and client_cert_source): + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): api_endpoint = cls.DEFAULT_MTLS_ENDPOINT else: api_endpoint = cls.DEFAULT_ENDPOINT return api_endpoint, client_cert_source - def __init__(self, *, - credentials: Optional[ga_credentials.Credentials] = None, - transport: Optional[Union[str, TextServiceTransport]] = None, - client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - ) -> None: + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[Union[str, TextServiceTransport]] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: """Instantiates the text service client. Args: @@ -341,11 +390,15 @@ def __init__(self, *, client_options = client_options_lib.ClientOptions() client_options = cast(client_options_lib.ClientOptions, client_options) - api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source(client_options) + api_endpoint, client_cert_source_func = self.get_mtls_endpoint_and_cert_source( + client_options + ) api_key_value = getattr(client_options, "api_key", None) if api_key_value and credentials: - raise ValueError("client_options.api_key and credentials are mutually exclusive") + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) # Save or instantiate the transport. # Ordinarily, we provide the transport, but allowing a custom transport @@ -353,8 +406,10 @@ def __init__(self, *, if isinstance(transport, TextServiceTransport): # transport is a TextServiceTransport instance. if credentials or client_options.credentials_file or api_key_value: - raise ValueError("When providing a transport instance, " - "provide its credentials directly.") + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." 
+ ) if client_options.scopes: raise ValueError( "When providing a transport instance, provide its scopes " @@ -364,8 +419,12 @@ def __init__(self, *, else: import google.auth._default # type: ignore - if api_key_value and hasattr(google.auth._default, "get_api_key_credentials"): - credentials = google.auth._default.get_api_key_credentials(api_key_value) + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) Transport = type(self).get_transport_class(transport) self._transport = Transport( @@ -380,20 +439,21 @@ def __init__(self, *, api_audience=client_options.api_audience, ) - def generate_text(self, - request: Optional[Union[text_service.GenerateTextRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[text_service.TextPrompt] = None, - temperature: Optional[float] = None, - candidate_count: Optional[int] = None, - max_output_tokens: Optional[int] = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> text_service.GenerateTextResponse: + def generate_text( + self, + request: Optional[Union[text_service.GenerateTextRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[text_service.TextPrompt] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + max_output_tokens: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.GenerateTextResponse: r"""Generates a response from the model given an input message. @@ -534,10 +594,22 @@ def sample_generate_text(): # Create or coerce a protobuf request object. # Quick check: If we got a request object, we should *not* have # gotten any keyword arguments that map to the request. - has_flattened_params = any([model, prompt, temperature, candidate_count, max_output_tokens, top_p, top_k]) + has_flattened_params = any( + [ + model, + prompt, + temperature, + candidate_count, + max_output_tokens, + top_p, + top_k, + ] + ) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a text_service.GenerateTextRequest. @@ -566,12 +638,10 @@ def sample_generate_text(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.generate_text] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), ) # Send the request. @@ -585,15 +655,16 @@ def sample_generate_text(): # Done; return the response. 
return response - def embed_text(self, - request: Optional[Union[text_service.EmbedTextRequest, dict]] = None, - *, - model: Optional[str] = None, - text: Optional[str] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> text_service.EmbedTextResponse: + def embed_text( + self, + request: Optional[Union[text_service.EmbedTextRequest, dict]] = None, + *, + model: Optional[str] = None, + text: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.EmbedTextResponse: r"""Generates an embedding from the model given an input message. @@ -658,8 +729,10 @@ def sample_embed_text(): # gotten any keyword arguments that map to the request. has_flattened_params = any([model, text]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a text_service.EmbedTextRequest. @@ -678,12 +751,10 @@ def sample_embed_text(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.embed_text] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), ) # Send the request. @@ -697,15 +768,16 @@ def sample_embed_text(): # Done; return the response. return response - def batch_embed_text(self, - request: Optional[Union[text_service.BatchEmbedTextRequest, dict]] = None, - *, - model: Optional[str] = None, - texts: Optional[MutableSequence[str]] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> text_service.BatchEmbedTextResponse: + def batch_embed_text( + self, + request: Optional[Union[text_service.BatchEmbedTextRequest, dict]] = None, + *, + model: Optional[str] = None, + texts: Optional[MutableSequence[str]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.BatchEmbedTextResponse: r"""Generates multiple embeddings from the model given input text in a synchronous call. @@ -773,8 +845,10 @@ def sample_batch_embed_text(): # gotten any keyword arguments that map to the request. has_flattened_params = any([model, texts]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a text_service.BatchEmbedTextRequest. @@ -793,12 +867,10 @@ def sample_batch_embed_text(): # and friendly error handling. 
rpc = self._transport._wrapped_methods[self._transport.batch_embed_text] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), ) # Send the request. @@ -812,15 +884,16 @@ def sample_batch_embed_text(): # Done; return the response. return response - def count_text_tokens(self, - request: Optional[Union[text_service.CountTextTokensRequest, dict]] = None, - *, - model: Optional[str] = None, - prompt: Optional[text_service.TextPrompt] = None, - retry: OptionalRetry = gapic_v1.method.DEFAULT, - timeout: Union[float, object] = gapic_v1.method.DEFAULT, - metadata: Sequence[Tuple[str, str]] = (), - ) -> text_service.CountTextTokensResponse: + def count_text_tokens( + self, + request: Optional[Union[text_service.CountTextTokensRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[text_service.TextPrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.CountTextTokensResponse: r"""Runs a model's tokenizer on a text and returns the token count. @@ -898,8 +971,10 @@ def sample_count_text_tokens(): # gotten any keyword arguments that map to the request. has_flattened_params = any([model, prompt]) if request is not None and has_flattened_params: - raise ValueError('If the `request` argument is set, then none of ' - 'the individual field arguments should be set.') + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) # Minor optimization to avoid making a copy if the user passes # in a text_service.CountTextTokensRequest. @@ -918,12 +993,10 @@ def sample_count_text_tokens(): # and friendly error handling. rpc = self._transport._wrapped_methods[self._transport.count_text_tokens] - # Certain fields should be provided within the metadata header; + # Certain fields should be provided within the metadata header; # add these here. metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ("model", request.model), - )), + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), ) # Send the request. 
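Each wrapped call above finishes by folding the request's `model` field into the outgoing metadata. `gapic_v1.routing_header.to_grpc_metadata` collapses those key/value pairs into a single `x-goog-request-params` header whose values are URL-encoded, which the backend uses for request routing. A small sketch with an illustrative resource name:

    from google.api_core import gapic_v1

    # Mirrors the metadata construction performed by the client methods above.
    header = gapic_v1.routing_header.to_grpc_metadata(
        (("model", "models/text-bison-001"),)
    )
    # Expected shape: ("x-goog-request-params", "model=models%2Ftext-bison-001")
    print(header)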
@@ -951,18 +1024,9 @@ def __exit__(self, type, value, traceback):
         self.transport.close()

+DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+    gapic_version=package_version.__version__
+)
-
-
-
-
-
-
-
-DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__)
-
-
-__all__ = (
-    "TextServiceClient",
-)
+__all__ = ("TextServiceClient",)
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/__init__.py
similarity index 68%
rename from owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/__init__.py
rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/__init__.py
index 71e949c7a4f5..63721cb6cb66 100644
--- a/owl-bot-staging/google-ai-generativelanguage/v1beta2/google/ai/generativelanguage_v1beta2/services/text_service/transports/__init__.py
+++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/__init__.py
@@ -19,20 +19,18 @@
 from .base import TextServiceTransport
 from .grpc import TextServiceGrpcTransport
 from .grpc_asyncio import TextServiceGrpcAsyncIOTransport
-from .rest import TextServiceRestTransport
-from .rest import TextServiceRestInterceptor
-
+from .rest import TextServiceRestInterceptor, TextServiceRestTransport

 # Compile a registry of transports.
 _transport_registry = OrderedDict()  # type: Dict[str, Type[TextServiceTransport]]
-_transport_registry['grpc'] = TextServiceGrpcTransport
-_transport_registry['grpc_asyncio'] = TextServiceGrpcAsyncIOTransport
-_transport_registry['rest'] = TextServiceRestTransport
+_transport_registry["grpc"] = TextServiceGrpcTransport
+_transport_registry["grpc_asyncio"] = TextServiceGrpcAsyncIOTransport
+_transport_registry["rest"] = TextServiceRestTransport

 __all__ = (
-    'TextServiceTransport',
-    'TextServiceGrpcTransport',
-    'TextServiceGrpcAsyncIOTransport',
-    'TextServiceRestTransport',
-    'TextServiceRestInterceptor',
+    "TextServiceTransport",
+    "TextServiceGrpcTransport",
+    "TextServiceGrpcAsyncIOTransport",
+    "TextServiceRestTransport",
+    "TextServiceRestInterceptor",
 )
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/base.py
similarity index 63%
rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/base.py
rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/base.py
index ab8ddc3e423d..47e7da1e8ce8 100644
--- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/base.py
+++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/base.py
@@ -16,41 +16,43 @@
 import abc
 from typing import Awaitable, Callable, Dict, Optional, Sequence, Union

-from google.ai.generativelanguage_v1beta3 import gapic_version as package_version
-
-import google.auth # type: ignore
 import google.api_core
 from google.api_core import exceptions as core_exceptions
 from google.api_core import gapic_v1
 from google.api_core import retry as retries
+import google.auth  # type: ignore
 from google.auth import credentials as ga_credentials  # type: ignore
-from google.oauth2 import service_account # type: ignore
+from google.longrunning import operations_pb2  # type: ignore
+from google.oauth2 import service_account  # type: ignore

+from google.ai.generativelanguage_v1beta3 import gapic_version as package_version
 from google.ai.generativelanguage_v1beta3.types import text_service
-from google.longrunning import operations_pb2 # type: ignore

-DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(gapic_version=package_version.__version__)
+DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+    gapic_version=package_version.__version__
+)


 class TextServiceTransport(abc.ABC):
     """Abstract transport class for TextService."""

-    AUTH_SCOPES = (
-    )
+    AUTH_SCOPES = ()
+
+    DEFAULT_HOST: str = "generativelanguage.googleapis.com"

-    DEFAULT_HOST: str = 'generativelanguage.googleapis.com'
     def __init__(
-            self, *,
-            host: str = DEFAULT_HOST,
-            credentials: Optional[ga_credentials.Credentials] = None,
-            credentials_file: Optional[str] = None,
-            scopes: Optional[Sequence[str]] = None,
-            quota_project_id: Optional[str] = None,
-            client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
-            always_use_jwt_access: Optional[bool] = False,
-            api_audience: Optional[str] = None,
-            **kwargs,
-            ) -> None:
+        self,
+        *,
+        host: str = DEFAULT_HOST,
+        credentials: Optional[ga_credentials.Credentials] = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        always_use_jwt_access: Optional[bool] = False,
+        api_audience: Optional[str] = None,
+        **kwargs,
+    ) -> None:
         """Instantiate the transport.

         Args:
@@ -84,30 +86,38 @@ def __init__(
         # If no credentials are provided, then determine the appropriate
         # defaults.
         if credentials and credentials_file:
-            raise core_exceptions.DuplicateCredentialArgs("'credentials_file' and 'credentials' are mutually exclusive")
+            raise core_exceptions.DuplicateCredentialArgs(
+                "'credentials_file' and 'credentials' are mutually exclusive"
+            )

         if credentials_file is not None:
             credentials, _ = google.auth.load_credentials_from_file(
-                credentials_file,
-                **scopes_kwargs,
-                quota_project_id=quota_project_id
-            )
+                credentials_file, **scopes_kwargs, quota_project_id=quota_project_id
+            )
         elif credentials is None:
-            credentials, _ = google.auth.default(**scopes_kwargs, quota_project_id=quota_project_id)
+            credentials, _ = google.auth.default(
+                **scopes_kwargs, quota_project_id=quota_project_id
+            )

         # Don't apply audience if the credentials file passed from user.
         if hasattr(credentials, "with_gdch_audience"):
-            credentials = credentials.with_gdch_audience(api_audience if api_audience else host)
+            credentials = credentials.with_gdch_audience(
+                api_audience if api_audience else host
+            )

         # If the credentials are service account credentials, then always try to use self signed JWT.
-        if always_use_jwt_access and isinstance(credentials, service_account.Credentials) and hasattr(service_account.Credentials, "with_always_use_jwt_access"):
+        if (
+            always_use_jwt_access
+            and isinstance(credentials, service_account.Credentials)
+            and hasattr(service_account.Credentials, "with_always_use_jwt_access")
+        ):
             credentials = credentials.with_always_use_jwt_access(True)

         # Save the credentials.
self._credentials = credentials # Save the hostname. Default to port 443 (HTTPS) if none is specified. - if ':' not in host: - host += ':443' + if ":" not in host: + host += ":443" self._host = host def _prep_wrapped_messages(self, client_info): @@ -133,51 +143,62 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), - } + } def close(self): """Closes resources associated with the transport. - .. warning:: - Only call this method if the transport is NOT shared - with other clients - this may cause errors in other clients! + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! """ raise NotImplementedError() @property - def generate_text(self) -> Callable[ - [text_service.GenerateTextRequest], - Union[ - text_service.GenerateTextResponse, - Awaitable[text_service.GenerateTextResponse] - ]]: + def generate_text( + self, + ) -> Callable[ + [text_service.GenerateTextRequest], + Union[ + text_service.GenerateTextResponse, + Awaitable[text_service.GenerateTextResponse], + ], + ]: raise NotImplementedError() @property - def embed_text(self) -> Callable[ - [text_service.EmbedTextRequest], - Union[ - text_service.EmbedTextResponse, - Awaitable[text_service.EmbedTextResponse] - ]]: + def embed_text( + self, + ) -> Callable[ + [text_service.EmbedTextRequest], + Union[ + text_service.EmbedTextResponse, Awaitable[text_service.EmbedTextResponse] + ], + ]: raise NotImplementedError() @property - def batch_embed_text(self) -> Callable[ - [text_service.BatchEmbedTextRequest], - Union[ - text_service.BatchEmbedTextResponse, - Awaitable[text_service.BatchEmbedTextResponse] - ]]: + def batch_embed_text( + self, + ) -> Callable[ + [text_service.BatchEmbedTextRequest], + Union[ + text_service.BatchEmbedTextResponse, + Awaitable[text_service.BatchEmbedTextResponse], + ], + ]: raise NotImplementedError() @property - def count_text_tokens(self) -> Callable[ - [text_service.CountTextTokensRequest], - Union[ - text_service.CountTextTokensResponse, - Awaitable[text_service.CountTextTokensResponse] - ]]: + def count_text_tokens( + self, + ) -> Callable[ + [text_service.CountTextTokensRequest], + Union[ + text_service.CountTextTokensResponse, + Awaitable[text_service.CountTextTokensResponse], + ], + ]: raise NotImplementedError() @property @@ -185,6 +206,4 @@ def kind(self) -> str: raise NotImplementedError() -__all__ = ( - 'TextServiceTransport', -) +__all__ = ("TextServiceTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc.py similarity index 79% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc.py index d3b0615ad633..6bb96215da87 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc.py @@ -13,20 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import warnings from typing import Callable, Dict, Optional, Sequence, Tuple, Union +import warnings -from google.api_core import grpc_helpers -from google.api_core import gapic_v1 -import google.auth # type: ignore +from google.api_core import gapic_v1, grpc_helpers +import google.auth # type: ignore from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - +from google.longrunning import operations_pb2 # type: ignore import grpc # type: ignore from google.ai.generativelanguage_v1beta3.types import text_service -from google.longrunning import operations_pb2 # type: ignore -from .base import TextServiceTransport, DEFAULT_CLIENT_INFO + +from .base import DEFAULT_CLIENT_INFO, TextServiceTransport class TextServiceGrpcTransport(TextServiceTransport): @@ -44,23 +43,26 @@ class TextServiceGrpcTransport(TextServiceTransport): It sends protocol buffers over the wire using gRPC (which is built on top of HTTP/2); the ``grpcio`` package must be installed. """ + _stubs: Dict[str, Callable] - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[grpc.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[grpc.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: """Instantiate the transport. Args: @@ -179,13 +181,15 @@ def __init__(self, *, self._prep_wrapped_messages(client_info) @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> grpc.Channel: + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: """Create and return a gRPC channel object. Args: host (Optional[str]): The host for the channel to use. 
@@ -220,19 +224,20 @@ def create_channel(cls, default_scopes=cls.AUTH_SCOPES, scopes=scopes, default_host=cls.DEFAULT_HOST, - **kwargs + **kwargs, ) @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service. - """ + """Return the channel designed to connect to this service.""" return self._grpc_channel @property - def generate_text(self) -> Callable[ - [text_service.GenerateTextRequest], - text_service.GenerateTextResponse]: + def generate_text( + self, + ) -> Callable[ + [text_service.GenerateTextRequest], text_service.GenerateTextResponse + ]: r"""Return a callable for the generate text method over gRPC. Generates a response from the model given an input @@ -248,18 +253,18 @@ def generate_text(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'generate_text' not in self._stubs: - self._stubs['generate_text'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.TextService/GenerateText', + if "generate_text" not in self._stubs: + self._stubs["generate_text"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.TextService/GenerateText", request_serializer=text_service.GenerateTextRequest.serialize, response_deserializer=text_service.GenerateTextResponse.deserialize, ) - return self._stubs['generate_text'] + return self._stubs["generate_text"] @property - def embed_text(self) -> Callable[ - [text_service.EmbedTextRequest], - text_service.EmbedTextResponse]: + def embed_text( + self, + ) -> Callable[[text_service.EmbedTextRequest], text_service.EmbedTextResponse]: r"""Return a callable for the embed text method over gRPC. Generates an embedding from the model given an input @@ -275,18 +280,20 @@ def embed_text(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'embed_text' not in self._stubs: - self._stubs['embed_text'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.TextService/EmbedText', + if "embed_text" not in self._stubs: + self._stubs["embed_text"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.TextService/EmbedText", request_serializer=text_service.EmbedTextRequest.serialize, response_deserializer=text_service.EmbedTextResponse.deserialize, ) - return self._stubs['embed_text'] + return self._stubs["embed_text"] @property - def batch_embed_text(self) -> Callable[ - [text_service.BatchEmbedTextRequest], - text_service.BatchEmbedTextResponse]: + def batch_embed_text( + self, + ) -> Callable[ + [text_service.BatchEmbedTextRequest], text_service.BatchEmbedTextResponse + ]: r"""Return a callable for the batch embed text method over gRPC. Generates multiple embeddings from the model given @@ -302,18 +309,20 @@ def batch_embed_text(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'batch_embed_text' not in self._stubs: - self._stubs['batch_embed_text'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.TextService/BatchEmbedText', + if "batch_embed_text" not in self._stubs: + self._stubs["batch_embed_text"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.TextService/BatchEmbedText", request_serializer=text_service.BatchEmbedTextRequest.serialize, response_deserializer=text_service.BatchEmbedTextResponse.deserialize, ) - return self._stubs['batch_embed_text'] + return self._stubs["batch_embed_text"] @property - def count_text_tokens(self) -> Callable[ - [text_service.CountTextTokensRequest], - text_service.CountTextTokensResponse]: + def count_text_tokens( + self, + ) -> Callable[ + [text_service.CountTextTokensRequest], text_service.CountTextTokensResponse + ]: r"""Return a callable for the count text tokens method over gRPC. Runs a model's tokenizer on a text and returns the @@ -329,13 +338,13 @@ def count_text_tokens(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'count_text_tokens' not in self._stubs: - self._stubs['count_text_tokens'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.TextService/CountTextTokens', + if "count_text_tokens" not in self._stubs: + self._stubs["count_text_tokens"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.TextService/CountTextTokens", request_serializer=text_service.CountTextTokensRequest.serialize, response_deserializer=text_service.CountTextTokensResponse.deserialize, ) - return self._stubs['count_text_tokens'] + return self._stubs["count_text_tokens"] def close(self): self.grpc_channel.close() @@ -345,6 +354,4 @@ def kind(self) -> str: return "grpc" -__all__ = ( - 'TextServiceGrpcTransport', -) +__all__ = ("TextServiceGrpcTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc_asyncio.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc_asyncio.py similarity index 79% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc_asyncio.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc_asyncio.py index 46ac7dc2d417..06ef9a24926d 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc_asyncio.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/grpc_asyncio.py @@ -13,20 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import warnings from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union +import warnings -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers_async -from google.auth import credentials as ga_credentials # type: ignore +from google.api_core import gapic_v1, grpc_helpers_async +from google.auth import credentials as ga_credentials # type: ignore from google.auth.transport.grpc import SslCredentials # type: ignore - -import grpc # type: ignore +from google.longrunning import operations_pb2 # type: ignore +import grpc # type: ignore from grpc.experimental import aio # type: ignore from google.ai.generativelanguage_v1beta3.types import text_service -from google.longrunning import operations_pb2 # type: ignore -from .base import TextServiceTransport, DEFAULT_CLIENT_INFO + +from .base import DEFAULT_CLIENT_INFO, TextServiceTransport from .grpc import TextServiceGrpcTransport @@ -50,13 +49,15 @@ class TextServiceGrpcAsyncIOTransport(TextServiceTransport): _stubs: Dict[str, Callable] = {} @classmethod - def create_channel(cls, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - quota_project_id: Optional[str] = None, - **kwargs) -> aio.Channel: + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: """Create and return a gRPC AsyncIO channel object. Args: host (Optional[str]): The host for the channel to use. @@ -87,24 +88,26 @@ def create_channel(cls, default_scopes=cls.AUTH_SCOPES, scopes=scopes, default_host=cls.DEFAULT_HOST, - **kwargs + **kwargs, ) - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - channel: Optional[aio.Channel] = None, - api_mtls_endpoint: Optional[str] = None, - client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, - client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - api_audience: Optional[str] = None, - ) -> None: + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[aio.Channel] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: """Instantiate the transport. 
Args: @@ -233,9 +236,11 @@ def grpc_channel(self) -> aio.Channel: return self._grpc_channel @property - def generate_text(self) -> Callable[ - [text_service.GenerateTextRequest], - Awaitable[text_service.GenerateTextResponse]]: + def generate_text( + self, + ) -> Callable[ + [text_service.GenerateTextRequest], Awaitable[text_service.GenerateTextResponse] + ]: r"""Return a callable for the generate text method over gRPC. Generates a response from the model given an input @@ -251,18 +256,20 @@ def generate_text(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'generate_text' not in self._stubs: - self._stubs['generate_text'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.TextService/GenerateText', + if "generate_text" not in self._stubs: + self._stubs["generate_text"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.TextService/GenerateText", request_serializer=text_service.GenerateTextRequest.serialize, response_deserializer=text_service.GenerateTextResponse.deserialize, ) - return self._stubs['generate_text'] + return self._stubs["generate_text"] @property - def embed_text(self) -> Callable[ - [text_service.EmbedTextRequest], - Awaitable[text_service.EmbedTextResponse]]: + def embed_text( + self, + ) -> Callable[ + [text_service.EmbedTextRequest], Awaitable[text_service.EmbedTextResponse] + ]: r"""Return a callable for the embed text method over gRPC. Generates an embedding from the model given an input @@ -278,18 +285,21 @@ def embed_text(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'embed_text' not in self._stubs: - self._stubs['embed_text'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.TextService/EmbedText', + if "embed_text" not in self._stubs: + self._stubs["embed_text"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.TextService/EmbedText", request_serializer=text_service.EmbedTextRequest.serialize, response_deserializer=text_service.EmbedTextResponse.deserialize, ) - return self._stubs['embed_text'] + return self._stubs["embed_text"] @property - def batch_embed_text(self) -> Callable[ - [text_service.BatchEmbedTextRequest], - Awaitable[text_service.BatchEmbedTextResponse]]: + def batch_embed_text( + self, + ) -> Callable[ + [text_service.BatchEmbedTextRequest], + Awaitable[text_service.BatchEmbedTextResponse], + ]: r"""Return a callable for the batch embed text method over gRPC. Generates multiple embeddings from the model given @@ -305,18 +315,21 @@ def batch_embed_text(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. 
- if 'batch_embed_text' not in self._stubs: - self._stubs['batch_embed_text'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.TextService/BatchEmbedText', + if "batch_embed_text" not in self._stubs: + self._stubs["batch_embed_text"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.TextService/BatchEmbedText", request_serializer=text_service.BatchEmbedTextRequest.serialize, response_deserializer=text_service.BatchEmbedTextResponse.deserialize, ) - return self._stubs['batch_embed_text'] + return self._stubs["batch_embed_text"] @property - def count_text_tokens(self) -> Callable[ - [text_service.CountTextTokensRequest], - Awaitable[text_service.CountTextTokensResponse]]: + def count_text_tokens( + self, + ) -> Callable[ + [text_service.CountTextTokensRequest], + Awaitable[text_service.CountTextTokensResponse], + ]: r"""Return a callable for the count text tokens method over gRPC. Runs a model's tokenizer on a text and returns the @@ -332,18 +345,16 @@ def count_text_tokens(self) -> Callable[ # the request. # gRPC handles serialization and deserialization, so we just need # to pass in the functions for each. - if 'count_text_tokens' not in self._stubs: - self._stubs['count_text_tokens'] = self.grpc_channel.unary_unary( - '/google.ai.generativelanguage.v1beta3.TextService/CountTextTokens', + if "count_text_tokens" not in self._stubs: + self._stubs["count_text_tokens"] = self.grpc_channel.unary_unary( + "/google.ai.generativelanguage.v1beta3.TextService/CountTextTokens", request_serializer=text_service.CountTextTokensRequest.serialize, response_deserializer=text_service.CountTextTokensResponse.deserialize, ) - return self._stubs['count_text_tokens'] + return self._stubs["count_text_tokens"] def close(self): return self.grpc_channel.close() -__all__ = ( - 'TextServiceGrpcAsyncIOTransport', -) +__all__ = ("TextServiceGrpcAsyncIOTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/rest.py similarity index 70% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/rest.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/rest.py index cdd184d866a4..dd0920f48a00 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/text_service/transports/rest.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/text_service/transports/rest.py @@ -14,24 +14,21 @@ # limitations under the License. 
# -from google.auth.transport.requests import AuthorizedSession # type: ignore +import dataclasses import json # type: ignore -import grpc # type: ignore -from google.auth.transport.grpc import SslCredentials # type: ignore -from google.auth import credentials as ga_credentials # type: ignore +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import gapic_v1, path_template, rest_helpers, rest_streaming from google.api_core import exceptions as core_exceptions from google.api_core import retry as retries -from google.api_core import rest_helpers -from google.api_core import rest_streaming -from google.api_core import path_template -from google.api_core import gapic_v1 - +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.transport.requests import AuthorizedSession # type: ignore from google.protobuf import json_format +import grpc # type: ignore from requests import __version__ as requests_version -import dataclasses -import re -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union -import warnings try: OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault] @@ -39,11 +36,12 @@ OptionalRetry = Union[retries.Retry, object] # type: ignore -from google.ai.generativelanguage_v1beta3.types import text_service from google.longrunning import operations_pb2 # type: ignore -from .base import TextServiceTransport, DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from google.ai.generativelanguage_v1beta3.types import text_service +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .base import TextServiceTransport DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, @@ -104,7 +102,12 @@ def post_generate_text(self, response): """ - def pre_batch_embed_text(self, request: text_service.BatchEmbedTextRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[text_service.BatchEmbedTextRequest, Sequence[Tuple[str, str]]]: + + def pre_batch_embed_text( + self, + request: text_service.BatchEmbedTextRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[text_service.BatchEmbedTextRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for batch_embed_text Override in a subclass to manipulate the request or metadata @@ -112,7 +115,9 @@ def pre_batch_embed_text(self, request: text_service.BatchEmbedTextRequest, meta """ return request, metadata - def post_batch_embed_text(self, response: text_service.BatchEmbedTextResponse) -> text_service.BatchEmbedTextResponse: + def post_batch_embed_text( + self, response: text_service.BatchEmbedTextResponse + ) -> text_service.BatchEmbedTextResponse: """Post-rpc interceptor for batch_embed_text Override in a subclass to manipulate the response @@ -120,7 +125,12 @@ def post_batch_embed_text(self, response: text_service.BatchEmbedTextResponse) - it is returned to user code. 
""" return response - def pre_count_text_tokens(self, request: text_service.CountTextTokensRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[text_service.CountTextTokensRequest, Sequence[Tuple[str, str]]]: + + def pre_count_text_tokens( + self, + request: text_service.CountTextTokensRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[text_service.CountTextTokensRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for count_text_tokens Override in a subclass to manipulate the request or metadata @@ -128,7 +138,9 @@ def pre_count_text_tokens(self, request: text_service.CountTextTokensRequest, me """ return request, metadata - def post_count_text_tokens(self, response: text_service.CountTextTokensResponse) -> text_service.CountTextTokensResponse: + def post_count_text_tokens( + self, response: text_service.CountTextTokensResponse + ) -> text_service.CountTextTokensResponse: """Post-rpc interceptor for count_text_tokens Override in a subclass to manipulate the response @@ -136,7 +148,12 @@ def post_count_text_tokens(self, response: text_service.CountTextTokensResponse) it is returned to user code. """ return response - def pre_embed_text(self, request: text_service.EmbedTextRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[text_service.EmbedTextRequest, Sequence[Tuple[str, str]]]: + + def pre_embed_text( + self, + request: text_service.EmbedTextRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[text_service.EmbedTextRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for embed_text Override in a subclass to manipulate the request or metadata @@ -144,7 +161,9 @@ def pre_embed_text(self, request: text_service.EmbedTextRequest, metadata: Seque """ return request, metadata - def post_embed_text(self, response: text_service.EmbedTextResponse) -> text_service.EmbedTextResponse: + def post_embed_text( + self, response: text_service.EmbedTextResponse + ) -> text_service.EmbedTextResponse: """Post-rpc interceptor for embed_text Override in a subclass to manipulate the response @@ -152,7 +171,12 @@ def post_embed_text(self, response: text_service.EmbedTextResponse) -> text_serv it is returned to user code. 
""" return response - def pre_generate_text(self, request: text_service.GenerateTextRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[text_service.GenerateTextRequest, Sequence[Tuple[str, str]]]: + + def pre_generate_text( + self, + request: text_service.GenerateTextRequest, + metadata: Sequence[Tuple[str, str]], + ) -> Tuple[text_service.GenerateTextRequest, Sequence[Tuple[str, str]]]: """Pre-rpc interceptor for generate_text Override in a subclass to manipulate the request or metadata @@ -160,7 +184,9 @@ def pre_generate_text(self, request: text_service.GenerateTextRequest, metadata: """ return request, metadata - def post_generate_text(self, response: text_service.GenerateTextResponse) -> text_service.GenerateTextResponse: + def post_generate_text( + self, response: text_service.GenerateTextResponse + ) -> text_service.GenerateTextResponse: """Post-rpc interceptor for generate_text Override in a subclass to manipulate the response @@ -193,20 +219,21 @@ class TextServiceRestTransport(TextServiceTransport): """ - def __init__(self, *, - host: str = 'generativelanguage.googleapis.com', - credentials: Optional[ga_credentials.Credentials] = None, - credentials_file: Optional[str] = None, - scopes: Optional[Sequence[str]] = None, - client_cert_source_for_mtls: Optional[Callable[[ - ], Tuple[bytes, bytes]]] = None, - quota_project_id: Optional[str] = None, - client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, - always_use_jwt_access: Optional[bool] = False, - url_scheme: str = 'https', - interceptor: Optional[TextServiceRestInterceptor] = None, - api_audience: Optional[str] = None, - ) -> None: + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[TextServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: """Instantiate the transport. 
Args: @@ -245,7 +272,9 @@ def __init__(self, *, # credentials object maybe_url_match = re.match("^(?P<scheme>http(?:s)?://)?(?P<host>.*)$", host) if maybe_url_match is None: - raise ValueError(f"Unexpected hostname structure: {host}") # pragma: NO COVER + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER url_match_items = maybe_url_match.groupdict() @@ -256,10 +285,11 @@ def __init__(self, *, credentials=credentials, client_info=client_info, always_use_jwt_access=always_use_jwt_access, - api_audience=api_audience + api_audience=api_audience, ) self._session = AuthorizedSession( - self._credentials, default_host=self.DEFAULT_HOST) + self._credentials, default_host=self.DEFAULT_HOST + ) if client_cert_source_for_mtls: self._session.configure_mtls_channel(client_cert_source_for_mtls) self._interceptor = interceptor or TextServiceRestInterceptor() @@ -269,19 +299,24 @@ class _BatchEmbedText(TextServiceRestStub): def __hash__(self): return hash("BatchEmbedText") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: text_service.BatchEmbedTextRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> text_service.BatchEmbedTextResponse: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: text_service.BatchEmbedTextRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.BatchEmbedTextResponse: r"""Call the batch embed text method over HTTP. Args: @@ -299,46 +334,51 @@ def __call__(self, The response to a EmbedTextRequest.
""" - http_options: List[Dict[str, str]] = [{ - 'method': 'post', - 'uri': '/v1beta3/{model=models/*}:batchEmbedText', - 'body': '*', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1beta3/{model=models/*}:batchEmbedText", + "body": "*", + }, ] - request, metadata = self._interceptor.pre_batch_embed_text(request, metadata) + request, metadata = self._interceptor.pre_batch_embed_text( + request, metadata + ) pb_request = text_service.BatchEmbedTextRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) # Jsonify the request body body = json_format.MessageToJson( - transcoded_request['body'], + transcoded_request["body"], including_default_value_fields=False, - use_integers_for_enums=True + use_integers_for_enums=True, ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), data=body, - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -357,19 +397,24 @@ class _CountTextTokens(TextServiceRestStub): def __hash__(self): return hash("CountTextTokens") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: text_service.CountTextTokensRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> text_service.CountTextTokensResponse: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: text_service.CountTextTokensRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.CountTextTokensResponse: r"""Call the count text tokens method over HTTP. 
Args: @@ -394,46 +439,51 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'post', - 'uri': '/v1beta3/{model=models/*}:countTextTokens', - 'body': '*', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1beta3/{model=models/*}:countTextTokens", + "body": "*", + }, ] - request, metadata = self._interceptor.pre_count_text_tokens(request, metadata) + request, metadata = self._interceptor.pre_count_text_tokens( + request, metadata + ) pb_request = text_service.CountTextTokensRequest.pb(request) transcoded_request = path_template.transcode(http_options, pb_request) # Jsonify the request body body = json_format.MessageToJson( - transcoded_request['body'], + transcoded_request["body"], including_default_value_fields=False, - use_integers_for_enums=True + use_integers_for_enums=True, ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), data=body, - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -452,19 +502,24 @@ class _EmbedText(TextServiceRestStub): def __hash__(self): return hash("EmbedText") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: text_service.EmbedTextRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> text_service.EmbedTextResponse: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: text_service.EmbedTextRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.EmbedTextResponse: r"""Call the embed text method over HTTP. Args: @@ -482,11 +537,12 @@ def __call__(self, The response to a EmbedTextRequest. 
""" - http_options: List[Dict[str, str]] = [{ - 'method': 'post', - 'uri': '/v1beta3/{model=models/*}:embedText', - 'body': '*', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1beta3/{model=models/*}:embedText", + "body": "*", + }, ] request, metadata = self._interceptor.pre_embed_text(request, metadata) pb_request = text_service.EmbedTextRequest.pb(request) @@ -495,33 +551,35 @@ def __call__(self, # Jsonify the request body body = json_format.MessageToJson( - transcoded_request['body'], + transcoded_request["body"], including_default_value_fields=False, - use_integers_for_enums=True + use_integers_for_enums=True, ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), data=body, - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -540,19 +598,24 @@ class _GenerateText(TextServiceRestStub): def __hash__(self): return hash("GenerateText") - __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { - } + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} @classmethod def _get_unset_required_fields(cls, message_dict): - return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} - - def __call__(self, - request: text_service.GenerateTextRequest, *, - retry: OptionalRetry=gapic_v1.method.DEFAULT, - timeout: Optional[float]=None, - metadata: Sequence[Tuple[str, str]]=(), - ) -> text_service.GenerateTextResponse: + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + def __call__( + self, + request: text_service.GenerateTextRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> text_service.GenerateTextResponse: r"""Call the generate text method over HTTP. 
Args: @@ -572,16 +635,17 @@ def __call__(self, """ - http_options: List[Dict[str, str]] = [{ - 'method': 'post', - 'uri': '/v1beta3/{model=models/*}:generateText', - 'body': '*', - }, -{ - 'method': 'post', - 'uri': '/v1beta3/{model=tunedModels/*}:generateText', - 'body': '*', - }, + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1beta3/{model=models/*}:generateText", + "body": "*", + }, + { + "method": "post", + "uri": "/v1beta3/{model=tunedModels/*}:generateText", + "body": "*", + }, ] request, metadata = self._interceptor.pre_generate_text(request, metadata) pb_request = text_service.GenerateTextRequest.pb(request) @@ -590,33 +654,35 @@ def __call__(self, # Jsonify the request body body = json_format.MessageToJson( - transcoded_request['body'], + transcoded_request["body"], including_default_value_fields=False, - use_integers_for_enums=True + use_integers_for_enums=True, ) - uri = transcoded_request['uri'] - method = transcoded_request['method'] + uri = transcoded_request["uri"] + method = transcoded_request["method"] # Jsonify the query params - query_params = json.loads(json_format.MessageToJson( - transcoded_request['query_params'], - including_default_value_fields=False, - use_integers_for_enums=True, - )) + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + including_default_value_fields=False, + use_integers_for_enums=True, + ) + ) query_params.update(self._get_unset_required_fields(query_params)) query_params["$alt"] = "json;enum-encoding=int" # Send the request headers = dict(metadata) - headers['Content-Type'] = 'application/json' + headers["Content-Type"] = "application/json" response = getattr(self._session, method)( "{host}{uri}".format(host=self._host, uri=uri), timeout=timeout, headers=headers, params=rest_helpers.flatten_query_params(query_params, strict=True), data=body, - ) + ) # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception # subclass. @@ -632,36 +698,42 @@ def __call__(self, return resp @property - def batch_embed_text(self) -> Callable[ - [text_service.BatchEmbedTextRequest], - text_service.BatchEmbedTextResponse]: + def batch_embed_text( + self, + ) -> Callable[ + [text_service.BatchEmbedTextRequest], text_service.BatchEmbedTextResponse + ]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._BatchEmbedText(self._session, self._host, self._interceptor) # type: ignore + return self._BatchEmbedText(self._session, self._host, self._interceptor) # type: ignore @property - def count_text_tokens(self) -> Callable[ - [text_service.CountTextTokensRequest], - text_service.CountTextTokensResponse]: + def count_text_tokens( + self, + ) -> Callable[ + [text_service.CountTextTokensRequest], text_service.CountTextTokensResponse + ]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._CountTextTokens(self._session, self._host, self._interceptor) # type: ignore + return self._CountTextTokens(self._session, self._host, self._interceptor) # type: ignore @property - def embed_text(self) -> Callable[ - [text_service.EmbedTextRequest], - text_service.EmbedTextResponse]: + def embed_text( + self, + ) -> Callable[[text_service.EmbedTextRequest], text_service.EmbedTextResponse]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
# In C++ this would require a dynamic_cast - return self._EmbedText(self._session, self._host, self._interceptor) # type: ignore + return self._EmbedText(self._session, self._host, self._interceptor) # type: ignore @property - def generate_text(self) -> Callable[ - [text_service.GenerateTextRequest], - text_service.GenerateTextResponse]: + def generate_text( + self, + ) -> Callable[ + [text_service.GenerateTextRequest], text_service.GenerateTextResponse + ]: # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. # In C++ this would require a dynamic_cast - return self._GenerateText(self._session, self._host, self._interceptor) # type: ignore + return self._GenerateText(self._session, self._host, self._interceptor) # type: ignore @property def kind(self) -> str: @@ -671,6 +743,4 @@ def close(self): self._session.close() -__all__=( - 'TextServiceRestTransport', -) +__all__ = ("TextServiceRestTransport",) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/__init__.py similarity index 56% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/__init__.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/__init__.py index b2a054b00c36..85ef888f3ac9 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/__init__.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/__init__.py @@ -13,10 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from .citation import ( - CitationMetadata, - CitationSource, -) +from .citation import CitationMetadata, CitationSource from .discuss_service import ( CountMessageTokensRequest, CountMessageTokensResponse, @@ -26,9 +23,7 @@ Message, MessagePrompt, ) -from .model import ( - Model, -) +from .model import Model from .model_service import ( CreateTunedModelMetadata, CreateTunedModelRequest, @@ -41,9 +36,7 @@ ListTunedModelsResponse, UpdateTunedModelRequest, ) -from .permission import ( - Permission, -) +from .permission import Permission from .permission_service import ( CreatePermissionRequest, DeletePermissionRequest, @@ -56,10 +49,10 @@ ) from .safety import ( ContentFilter, + HarmCategory, SafetyFeedback, SafetyRating, SafetySetting, - HarmCategory, ) from .text_service import ( BatchEmbedTextRequest, @@ -86,57 +79,57 @@ ) __all__ = ( - 'CitationMetadata', - 'CitationSource', - 'CountMessageTokensRequest', - 'CountMessageTokensResponse', - 'Example', - 'GenerateMessageRequest', - 'GenerateMessageResponse', - 'Message', - 'MessagePrompt', - 'Model', - 'CreateTunedModelMetadata', - 'CreateTunedModelRequest', - 'DeleteTunedModelRequest', - 'GetModelRequest', - 'GetTunedModelRequest', - 'ListModelsRequest', - 'ListModelsResponse', - 'ListTunedModelsRequest', - 'ListTunedModelsResponse', - 'UpdateTunedModelRequest', - 'Permission', - 'CreatePermissionRequest', - 'DeletePermissionRequest', - 'GetPermissionRequest', - 'ListPermissionsRequest', - 'ListPermissionsResponse', - 'TransferOwnershipRequest', - 'TransferOwnershipResponse', - 'UpdatePermissionRequest', - 'ContentFilter', - 'SafetyFeedback', - 'SafetyRating', - 'SafetySetting', - 'HarmCategory', - 'BatchEmbedTextRequest', - 'BatchEmbedTextResponse', - 'CountTextTokensRequest', - 'CountTextTokensResponse', - 'Embedding', - 'EmbedTextRequest', - 'EmbedTextResponse', - 'GenerateTextRequest', - 'GenerateTextResponse', - 'TextCompletion', - 'TextPrompt', - 'Dataset', - 'Hyperparameters', - 'TunedModel', - 'TunedModelSource', - 'TuningExample', - 'TuningExamples', - 'TuningSnapshot', - 'TuningTask', + "CitationMetadata", + "CitationSource", + "CountMessageTokensRequest", + "CountMessageTokensResponse", + "Example", + "GenerateMessageRequest", + "GenerateMessageResponse", + "Message", + "MessagePrompt", + "Model", + "CreateTunedModelMetadata", + "CreateTunedModelRequest", + "DeleteTunedModelRequest", + "GetModelRequest", + "GetTunedModelRequest", + "ListModelsRequest", + "ListModelsResponse", + "ListTunedModelsRequest", + "ListTunedModelsResponse", + "UpdateTunedModelRequest", + "Permission", + "CreatePermissionRequest", + "DeletePermissionRequest", + "GetPermissionRequest", + "ListPermissionsRequest", + "ListPermissionsResponse", + "TransferOwnershipRequest", + "TransferOwnershipResponse", + "UpdatePermissionRequest", + "ContentFilter", + "SafetyFeedback", + "SafetyRating", + "SafetySetting", + "HarmCategory", + "BatchEmbedTextRequest", + "BatchEmbedTextResponse", + "CountTextTokensRequest", + "CountTextTokensResponse", + "Embedding", + "EmbedTextRequest", + "EmbedTextResponse", + "GenerateTextRequest", + "GenerateTextResponse", + "TextCompletion", + "TextPrompt", + "Dataset", + "Hyperparameters", + "TunedModel", + "TunedModelSource", + "TuningExample", + "TuningExamples", + "TuningSnapshot", + "TuningTask", ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/citation.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/citation.py similarity index 
92% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/citation.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/citation.py index f7ea0d176c60..26d17701a3ed 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/citation.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/citation.py @@ -19,12 +19,11 @@ import proto # type: ignore - __protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta3', + package="google.ai.generativelanguage.v1beta3", manifest={ - 'CitationMetadata', - 'CitationSource', + "CitationMetadata", + "CitationSource", }, ) @@ -37,10 +36,10 @@ class CitationMetadata(proto.Message): Citations to sources for a specific response. """ - citation_sources: MutableSequence['CitationSource'] = proto.RepeatedField( + citation_sources: MutableSequence["CitationSource"] = proto.RepeatedField( proto.MESSAGE, number=1, - message='CitationSource', + message="CitationSource", ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/discuss_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/discuss_service.py similarity index 91% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/discuss_service.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/discuss_service.py index 4c731553dbbc..0bf6ff74f603 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/discuss_service.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/discuss_service.py @@ -19,20 +19,18 @@ import proto # type: ignore -from google.ai.generativelanguage_v1beta3.types import citation -from google.ai.generativelanguage_v1beta3.types import safety - +from google.ai.generativelanguage_v1beta3.types import citation, safety __protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta3', + package="google.ai.generativelanguage.v1beta3", manifest={ - 'GenerateMessageRequest', - 'GenerateMessageResponse', - 'Message', - 'MessagePrompt', - 'Example', - 'CountMessageTokensRequest', - 'CountMessageTokensResponse', + "GenerateMessageRequest", + "GenerateMessageResponse", + "Message", + "MessagePrompt", + "Example", + "CountMessageTokensRequest", + "CountMessageTokensResponse", }, ) @@ -96,10 +94,10 @@ class GenerateMessageRequest(proto.Message): proto.STRING, number=1, ) - prompt: 'MessagePrompt' = proto.Field( + prompt: "MessagePrompt" = proto.Field( proto.MESSAGE, number=2, - message='MessagePrompt', + message="MessagePrompt", ) temperature: float = proto.Field( proto.FLOAT, @@ -145,15 +143,15 @@ class GenerateMessageResponse(proto.Message): that category. 
""" - candidates: MutableSequence['Message'] = proto.RepeatedField( + candidates: MutableSequence["Message"] = proto.RepeatedField( proto.MESSAGE, number=1, - message='Message', + message="Message", ) - messages: MutableSequence['Message'] = proto.RepeatedField( + messages: MutableSequence["Message"] = proto.RepeatedField( proto.MESSAGE, number=2, - message='Message', + message="Message", ) filters: MutableSequence[safety.ContentFilter] = proto.RepeatedField( proto.MESSAGE, @@ -267,15 +265,15 @@ class MessagePrompt(proto.Message): proto.STRING, number=1, ) - examples: MutableSequence['Example'] = proto.RepeatedField( + examples: MutableSequence["Example"] = proto.RepeatedField( proto.MESSAGE, number=2, - message='Example', + message="Example", ) - messages: MutableSequence['Message'] = proto.RepeatedField( + messages: MutableSequence["Message"] = proto.RepeatedField( proto.MESSAGE, number=3, - message='Message', + message="Message", ) @@ -293,15 +291,15 @@ class Example(proto.Message): output given the input. """ - input: 'Message' = proto.Field( + input: "Message" = proto.Field( proto.MESSAGE, number=1, - message='Message', + message="Message", ) - output: 'Message' = proto.Field( + output: "Message" = proto.Field( proto.MESSAGE, number=2, - message='Message', + message="Message", ) @@ -329,10 +327,10 @@ class CountMessageTokensRequest(proto.Message): proto.STRING, number=1, ) - prompt: 'MessagePrompt' = proto.Field( + prompt: "MessagePrompt" = proto.Field( proto.MESSAGE, number=2, - message='MessagePrompt', + message="MessagePrompt", ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/model.py similarity index 98% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/model.py index f5ac72ce872b..1e293084c980 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/model.py @@ -19,11 +19,10 @@ import proto # type: ignore - __protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta3', + package="google.ai.generativelanguage.v1beta3", manifest={ - 'Model', + "Model", }, ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/model_service.py similarity index 95% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model_service.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/model_service.py index f2f640d9da76..c5a88079a118 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/model_service.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/model_service.py @@ -17,26 +17,25 @@ from typing import MutableMapping, MutableSequence +from google.protobuf import field_mask_pb2 # type: ignore import proto # type: ignore -from google.ai.generativelanguage_v1beta3.types import model from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model -from google.protobuf import 
field_mask_pb2 # type: ignore - +from google.ai.generativelanguage_v1beta3.types import model __protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta3', + package="google.ai.generativelanguage.v1beta3", manifest={ - 'GetModelRequest', - 'ListModelsRequest', - 'ListModelsResponse', - 'GetTunedModelRequest', - 'ListTunedModelsRequest', - 'ListTunedModelsResponse', - 'CreateTunedModelRequest', - 'CreateTunedModelMetadata', - 'UpdateTunedModelRequest', - 'DeleteTunedModelRequest', + "GetModelRequest", + "ListModelsRequest", + "ListModelsResponse", + "GetTunedModelRequest", + "ListTunedModelsRequest", + "ListTunedModelsResponse", + "CreateTunedModelRequest", + "CreateTunedModelMetadata", + "UpdateTunedModelRequest", + "DeleteTunedModelRequest", }, ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/permission.py similarity index 98% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/permission.py index 7a5b9c7c14b3..115ca22e8bef 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/permission.py @@ -19,11 +19,10 @@ import proto # type: ignore - __protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta3', + package="google.ai.generativelanguage.v1beta3", manifest={ - 'Permission', + "Permission", }, ) @@ -73,6 +72,7 @@ class Permission(proto.Message): This field is a member of `oneof`_ ``_role``. """ + class GranteeType(proto.Enum): r"""Defines types of the grantee of this permission. 
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/permission_service.py similarity index 93% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission_service.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/permission_service.py index cb9c76ef3167..5499c997e5d6 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/permission_service.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/permission_service.py @@ -17,23 +17,22 @@ from typing import MutableMapping, MutableSequence +from google.protobuf import field_mask_pb2 # type: ignore import proto # type: ignore from google.ai.generativelanguage_v1beta3.types import permission as gag_permission -from google.protobuf import field_mask_pb2 # type: ignore - __protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta3', + package="google.ai.generativelanguage.v1beta3", manifest={ - 'CreatePermissionRequest', - 'GetPermissionRequest', - 'ListPermissionsRequest', - 'ListPermissionsResponse', - 'UpdatePermissionRequest', - 'DeletePermissionRequest', - 'TransferOwnershipRequest', - 'TransferOwnershipResponse', + "CreatePermissionRequest", + "GetPermissionRequest", + "ListPermissionsRequest", + "ListPermissionsResponse", + "UpdatePermissionRequest", + "DeletePermissionRequest", + "TransferOwnershipRequest", + "TransferOwnershipResponse", }, ) @@ -213,8 +212,7 @@ class TransferOwnershipRequest(proto.Message): class TransferOwnershipResponse(proto.Message): - r"""Response from ``TransferOwnership``. - """ + r"""Response from ``TransferOwnership``.""" __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/safety.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/safety.py similarity index 94% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/safety.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/safety.py index f33e790f3577..95ac9ace5341 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/safety.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/safety.py @@ -19,15 +19,14 @@ import proto # type: ignore - __protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta3', + package="google.ai.generativelanguage.v1beta3", manifest={ - 'HarmCategory', - 'ContentFilter', - 'SafetyFeedback', - 'SafetyRating', - 'SafetySetting', + "HarmCategory", + "ContentFilter", + "SafetyFeedback", + "SafetyRating", + "SafetySetting", }, ) @@ -88,6 +87,7 @@ class ContentFilter(proto.Message): This field is a member of `oneof`_ ``_message``. """ + class BlockedReason(proto.Enum): r"""A list of reasons why content may have been blocked. @@ -133,15 +133,15 @@ class SafetyFeedback(proto.Message): Safety settings applied to the request. 
""" - rating: 'SafetyRating' = proto.Field( + rating: "SafetyRating" = proto.Field( proto.MESSAGE, number=1, - message='SafetyRating', + message="SafetyRating", ) - setting: 'SafetySetting' = proto.Field( + setting: "SafetySetting" = proto.Field( proto.MESSAGE, number=2, - message='SafetySetting', + message="SafetySetting", ) @@ -161,6 +161,7 @@ class SafetyRating(proto.Message): Required. The probability of harm for this content. """ + class HarmProbability(proto.Enum): r"""The probability that a piece of content is harmful. @@ -187,10 +188,10 @@ class HarmProbability(proto.Enum): MEDIUM = 3 HIGH = 4 - category: 'HarmCategory' = proto.Field( + category: "HarmCategory" = proto.Field( proto.ENUM, number=3, - enum='HarmCategory', + enum="HarmCategory", ) probability: HarmProbability = proto.Field( proto.ENUM, @@ -212,6 +213,7 @@ class SafetySetting(proto.Message): Required. Controls the probability threshold at which harm is blocked. """ + class HarmBlockThreshold(proto.Enum): r"""Block at and beyond a specified harm probability. @@ -235,10 +237,10 @@ class HarmBlockThreshold(proto.Enum): BLOCK_ONLY_HIGH = 3 BLOCK_NONE = 4 - category: 'HarmCategory' = proto.Field( + category: "HarmCategory" = proto.Field( proto.ENUM, number=3, - enum='HarmCategory', + enum="HarmCategory", ) threshold: HarmBlockThreshold = proto.Field( proto.ENUM, diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/text_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/text_service.py similarity index 93% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/text_service.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/text_service.py index d347de6f0728..bd10a8560faf 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/text_service.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/text_service.py @@ -19,24 +19,22 @@ import proto # type: ignore -from google.ai.generativelanguage_v1beta3.types import citation -from google.ai.generativelanguage_v1beta3.types import safety - +from google.ai.generativelanguage_v1beta3.types import citation, safety __protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta3', + package="google.ai.generativelanguage.v1beta3", manifest={ - 'GenerateTextRequest', - 'GenerateTextResponse', - 'TextPrompt', - 'TextCompletion', - 'EmbedTextRequest', - 'EmbedTextResponse', - 'BatchEmbedTextRequest', - 'BatchEmbedTextResponse', - 'Embedding', - 'CountTextTokensRequest', - 'CountTextTokensResponse', + "GenerateTextRequest", + "GenerateTextResponse", + "TextPrompt", + "TextCompletion", + "EmbedTextRequest", + "EmbedTextResponse", + "BatchEmbedTextRequest", + "BatchEmbedTextResponse", + "Embedding", + "CountTextTokensRequest", + "CountTextTokensResponse", }, ) @@ -142,10 +140,10 @@ class GenerateTextRequest(proto.Message): proto.STRING, number=1, ) - prompt: 'TextPrompt' = proto.Field( + prompt: "TextPrompt" = proto.Field( proto.MESSAGE, number=2, - message='TextPrompt', + message="TextPrompt", ) temperature: float = proto.Field( proto.FLOAT, @@ -207,10 +205,10 @@ class GenerateTextResponse(proto.Message): content filtering. 
""" - candidates: MutableSequence['TextCompletion'] = proto.RepeatedField( + candidates: MutableSequence["TextCompletion"] = proto.RepeatedField( proto.MESSAGE, number=1, - message='TextCompletion', + message="TextCompletion", ) filters: MutableSequence[safety.ContentFilter] = proto.RepeatedField( proto.MESSAGE, @@ -316,11 +314,11 @@ class EmbedTextResponse(proto.Message): This field is a member of `oneof`_ ``_embedding``. """ - embedding: 'Embedding' = proto.Field( + embedding: "Embedding" = proto.Field( proto.MESSAGE, number=1, optional=True, - message='Embedding', + message="Embedding", ) @@ -357,10 +355,10 @@ class BatchEmbedTextResponse(proto.Message): the input text. """ - embeddings: MutableSequence['Embedding'] = proto.RepeatedField( + embeddings: MutableSequence["Embedding"] = proto.RepeatedField( proto.MESSAGE, number=1, - message='Embedding', + message="Embedding", ) @@ -402,10 +400,10 @@ class CountTextTokensRequest(proto.Message): proto.STRING, number=1, ) - prompt: 'TextPrompt' = proto.Field( + prompt: "TextPrompt" = proto.Field( proto.MESSAGE, number=2, - message='TextPrompt', + message="TextPrompt", ) diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/tuned_model.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/tuned_model.py similarity index 92% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/tuned_model.py rename to packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/tuned_model.py index 5fb0a44053d4..69e363d8e55f 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/types/tuned_model.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/tuned_model.py @@ -17,22 +17,20 @@ from typing import MutableMapping, MutableSequence -import proto # type: ignore - from google.protobuf import timestamp_pb2 # type: ignore - +import proto # type: ignore __protobuf__ = proto.module( - package='google.ai.generativelanguage.v1beta3', + package="google.ai.generativelanguage.v1beta3", manifest={ - 'TunedModel', - 'TunedModelSource', - 'TuningTask', - 'Hyperparameters', - 'Dataset', - 'TuningExamples', - 'TuningExample', - 'TuningSnapshot', + "TunedModel", + "TunedModelSource", + "TuningTask", + "Hyperparameters", + "Dataset", + "TuningExamples", + "TuningExample", + "TuningSnapshot", }, ) @@ -118,6 +116,7 @@ class TunedModel(proto.Message): Required. The tuning task that creates the tuned model. """ + class State(proto.Enum): r"""The state of the tuned model. 
@@ -136,16 +135,16 @@ class State(proto.Enum): ACTIVE = 2 FAILED = 3 - tuned_model_source: 'TunedModelSource' = proto.Field( + tuned_model_source: "TunedModelSource" = proto.Field( proto.MESSAGE, number=3, - oneof='source_model', - message='TunedModelSource', + oneof="source_model", + message="TunedModelSource", ) base_model: str = proto.Field( proto.STRING, number=4, - oneof='source_model', + oneof="source_model", ) name: str = proto.Field( proto.STRING, @@ -189,10 +188,10 @@ class State(proto.Enum): number=9, message=timestamp_pb2.Timestamp, ) - tuning_task: 'TuningTask' = proto.Field( + tuning_task: "TuningTask" = proto.Field( proto.MESSAGE, number=10, - message='TuningTask', + message="TuningTask", ) @@ -251,20 +250,20 @@ class TuningTask(proto.Message): number=2, message=timestamp_pb2.Timestamp, ) - snapshots: MutableSequence['TuningSnapshot'] = proto.RepeatedField( + snapshots: MutableSequence["TuningSnapshot"] = proto.RepeatedField( proto.MESSAGE, number=3, - message='TuningSnapshot', + message="TuningSnapshot", ) - training_data: 'Dataset' = proto.Field( + training_data: "Dataset" = proto.Field( proto.MESSAGE, number=4, - message='Dataset', + message="Dataset", ) - hyperparameters: 'Hyperparameters' = proto.Field( + hyperparameters: "Hyperparameters" = proto.Field( proto.MESSAGE, number=5, - message='Hyperparameters', + message="Hyperparameters", ) @@ -325,11 +324,11 @@ class Dataset(proto.Message): This field is a member of `oneof`_ ``dataset``. """ - examples: 'TuningExamples' = proto.Field( + examples: "TuningExamples" = proto.Field( proto.MESSAGE, number=1, - oneof='dataset', - message='TuningExamples', + oneof="dataset", + message="TuningExamples", ) @@ -344,10 +343,10 @@ class TuningExamples(proto.Message): must be of the same type. """ - examples: MutableSequence['TuningExample'] = proto.RepeatedField( + examples: MutableSequence["TuningExample"] = proto.RepeatedField( proto.MESSAGE, number=1, - message='TuningExample', + message="TuningExample", ) @@ -368,7 +367,7 @@ class TuningExample(proto.Message): text_input: str = proto.Field( proto.STRING, number=1, - oneof='model_input', + oneof="model_input", ) output: str = proto.Field( proto.STRING, diff --git a/packages/google-ai-generativelanguage/noxfile.py b/packages/google-ai-generativelanguage/noxfile.py index 6f5debd52f23..9a2acd8b6787 100644 --- a/packages/google-ai-generativelanguage/noxfile.py +++ b/packages/google-ai-generativelanguage/noxfile.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # -# Copyright 2018 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_count_message_tokens_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_discuss_service_generate_message_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_sync.py 
b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_create_tuned_model_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_delete_tuned_model_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_model_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_async.py 
rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_get_tuned_model_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_models_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_list_tuned_models_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_async.py 
b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_model_service_update_tuned_model_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_create_permission_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_sync.py similarity index 100% rename from 
owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_delete_permission_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_get_permission_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_list_permissions_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_async.py rename to 
packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_transfer_ownership_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_permission_service_update_permission_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_batch_embed_text_sync.py diff --git 
a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_count_text_tokens_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_sync.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_embed_text_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_async.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_async.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_async.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_sync.py similarity index 100% rename from 
owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_sync.py rename to packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1beta3_generated_text_service_generate_text_sync.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta3.json b/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta3.json similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta3.json rename to packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta3.json diff --git a/packages/google-ai-generativelanguage/scripts/decrypt-secrets.sh b/packages/google-ai-generativelanguage/scripts/decrypt-secrets.sh index 21f6d2a26d90..0018b421ddf8 100755 --- a/packages/google-ai-generativelanguage/scripts/decrypt-secrets.sh +++ b/packages/google-ai-generativelanguage/scripts/decrypt-secrets.sh @@ -1,6 +1,6 @@ #!/bin/bash -# Copyright 2015 Google Inc. All rights reserved. +# Copyright 2023 Google LLC All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/scripts/fixup_generativelanguage_v1beta3_keywords.py b/packages/google-ai-generativelanguage/scripts/fixup_generativelanguage_v1beta3_keywords.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/scripts/fixup_generativelanguage_v1beta3_keywords.py rename to packages/google-ai-generativelanguage/scripts/fixup_generativelanguage_v1beta3_keywords.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/__init__.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/__init__.py similarity index 100% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/google/ai/generativelanguage_v1beta3/services/__init__.py rename to packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/__init__.py diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_discuss_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/test_discuss_service.py similarity index 70% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_discuss_service.py rename to packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/test_discuss_service.py index fff2f9a81134..e436a0612d62 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_discuss_service.py +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/test_discuss_service.py @@ -14,6 +14,7 @@ # limitations under the License. 
# import os + # try/except added for compatibility with python < 3.8 try: from unittest import mock @@ -21,37 +22,33 @@ except ImportError: # pragma: NO COVER import mock -import grpc -from grpc.experimental import aio from collections.abc import Iterable -from google.protobuf import json_format import json import math -import pytest -from proto.marshal.rules.dates import DurationRule, TimestampRule -from proto.marshal.rules import wrappers -from requests import Response -from requests import Request, PreparedRequest -from requests.sessions import Session -from google.protobuf import json_format -from google.ai.generativelanguage_v1beta3.services.discuss_service import DiscussServiceAsyncClient -from google.ai.generativelanguage_v1beta3.services.discuss_service import DiscussServiceClient -from google.ai.generativelanguage_v1beta3.services.discuss_service import transports -from google.ai.generativelanguage_v1beta3.types import citation -from google.ai.generativelanguage_v1beta3.types import discuss_service -from google.ai.generativelanguage_v1beta3.types import safety +from google.api_core import gapic_v1, grpc_helpers, grpc_helpers_async, path_template from google.api_core import client_options from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import path_template +import google.auth from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError -from google.longrunning import operations_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore from google.oauth2 import service_account -import google.auth +from google.protobuf import json_format +import grpc +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +import pytest +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + +from google.ai.generativelanguage_v1beta3.services.discuss_service import ( + DiscussServiceAsyncClient, + DiscussServiceClient, + transports, +) +from google.ai.generativelanguage_v1beta3.types import citation, discuss_service, safety def client_cert_source_callback(): @@ -62,7 +59,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. 
def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -73,21 +74,40 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert DiscussServiceClient._get_default_mtls_endpoint(None) is None - assert DiscussServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert DiscussServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert DiscussServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert DiscussServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert DiscussServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - - -@pytest.mark.parametrize("client_class,transport_name", [ - (DiscussServiceClient, "grpc"), - (DiscussServiceAsyncClient, "grpc_asyncio"), - (DiscussServiceClient, "rest"), -]) + assert ( + DiscussServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + DiscussServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + DiscussServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DiscussServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DiscussServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (DiscussServiceClient, "grpc"), + (DiscussServiceAsyncClient, "grpc_asyncio"), + (DiscussServiceClient, "rest"), + ], +) def test_discuss_service_client_from_service_account_info(client_class, transport_name): creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info, transport=transport_name) @@ -95,52 +115,68 @@ def test_discuss_service_client_from_service_account_info(client_class, transpor assert isinstance(client, client_class) assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" ) -@pytest.mark.parametrize("transport_class,transport_name", [ - (transports.DiscussServiceGrpcTransport, "grpc"), - (transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (transports.DiscussServiceRestTransport, "rest"), -]) -def test_discuss_service_client_service_account_always_use_jwt(transport_class, transport_name): - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.DiscussServiceGrpcTransport, "grpc"), + (transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.DiscussServiceRestTransport, "rest"), + ], +) +def test_discuss_service_client_service_account_always_use_jwt( + transport_class, 
transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: creds = service_account.Credentials(None, None, None) transport = transport_class(credentials=creds, always_use_jwt_access=True) use_jwt.assert_called_once_with(True) - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: creds = service_account.Credentials(None, None, None) transport = transport_class(credentials=creds, always_use_jwt_access=False) use_jwt.assert_not_called() -@pytest.mark.parametrize("client_class,transport_name", [ - (DiscussServiceClient, "grpc"), - (DiscussServiceAsyncClient, "grpc_asyncio"), - (DiscussServiceClient, "rest"), -]) +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (DiscussServiceClient, "grpc"), + (DiscussServiceAsyncClient, "grpc_asyncio"), + (DiscussServiceClient, "rest"), + ], +) def test_discuss_service_client_from_service_account_file(client_class, transport_name): creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name) + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) assert client.transport._credentials == creds assert isinstance(client, client_class) - client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name) + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) assert client.transport._credentials == creds assert isinstance(client, client_class) assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" ) @@ -156,30 +192,45 @@ def test_discuss_service_client_get_transport_class(): assert transport == transports.DiscussServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc"), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest"), -]) -@mock.patch.object(DiscussServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceClient)) -@mock.patch.object(DiscussServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceAsyncClient)) -def test_discuss_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc"), + ( + DiscussServiceAsyncClient, + transports.DiscussServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + DiscussServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DiscussServiceClient), +) 
+@mock.patch.object( + DiscussServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DiscussServiceAsyncClient), +) +def test_discuss_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(DiscussServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=ga_credentials.AnonymousCredentials() - ) + with mock.patch.object(DiscussServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(DiscussServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(DiscussServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(transport=transport_name, client_options=options) patched.assert_called_once_with( @@ -197,7 +248,7 @@ def test_discuss_service_client_client_options(client_class, transport_class, tr # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(transport=transport_name) patched.assert_called_once_with( @@ -215,7 +266,7 @@ def test_discuss_service_client_client_options(client_class, transport_class, tr # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(transport=transport_name) patched.assert_called_once_with( @@ -237,13 +288,15 @@ def test_discuss_service_client_client_options(client_class, transport_class, tr client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -258,8 +311,10 @@ def test_discuss_service_client_client_options(client_class, transport_class, tr api_audience=None, ) # Check the case api_endpoint is provided - options = client_options.ClientOptions(api_audience="https://language.googleapis.com") - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -271,29 +326,57 @@ def test_discuss_service_client_client_options(client_class, transport_class, tr quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, always_use_jwt_access=True, - api_audience="https://language.googleapis.com" - ) - -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", "true"), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", "false"), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), - (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", "true"), - (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", "false"), -]) -@mock.patch.object(DiscussServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceClient)) -@mock.patch.object(DiscussServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceAsyncClient)) + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", "true"), + ( + DiscussServiceAsyncClient, + transports.DiscussServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", "false"), + ( + DiscussServiceAsyncClient, + transports.DiscussServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", "true"), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", "false"), + ], +) +@mock.patch.object( + DiscussServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DiscussServiceClient), +) +@mock.patch.object( + DiscussServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DiscussServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_discuss_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def 
test_discuss_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) @@ -318,10 +401,18 @@ def test_discuss_service_client_mtls_env_auto(client_class, transport_class, tra # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -344,9 +435,14 @@ def test_discuss_service_client_mtls_env_auto(client_class, transport_class, tra ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class(transport=transport_name) patched.assert_called_once_with( @@ -362,19 +458,31 @@ def test_discuss_service_client_mtls_env_auto(client_class, transport_class, tra ) -@pytest.mark.parametrize("client_class", [ - DiscussServiceClient, DiscussServiceAsyncClient -]) -@mock.patch.object(DiscussServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceClient)) -@mock.patch.object(DiscussServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(DiscussServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class", [DiscussServiceClient, DiscussServiceAsyncClient] +) +@mock.patch.object( + DiscussServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DiscussServiceClient), +) +@mock.patch.object( + DiscussServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DiscussServiceAsyncClient), +) def test_discuss_service_client_get_mtls_endpoint_and_cert_source(client_class): mock_client_cert_source = mock.Mock() # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) assert api_endpoint == mock_api_endpoint assert cert_source == mock_client_cert_source @@ -382,8 +490,12 @@ def test_discuss_service_client_get_mtls_endpoint_and_cert_source(client_class): with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): mock_client_cert_source = mock.Mock() mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) assert api_endpoint == mock_api_endpoint assert cert_source is None @@ -401,31 +513,52 @@ def test_discuss_service_client_get_mtls_endpoint_and_cert_source(client_class): # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() assert api_endpoint == client_class.DEFAULT_ENDPOINT assert cert_source is None # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT assert cert_source == mock_client_cert_source -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc"), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest"), -]) -def test_discuss_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc"), + ( + DiscussServiceAsyncClient, + transports.DiscussServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest"), + ], +) +def test_discuss_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. 
options = client_options.ClientOptions( scopes=["1", "2"], ) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -440,18 +573,32 @@ def test_discuss_service_client_client_options_scopes(client_class, transport_cl api_audience=None, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", grpc_helpers), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), - (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", None), -]) -def test_discuss_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + DiscussServiceClient, + transports.DiscussServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + DiscussServiceAsyncClient, + transports.DiscussServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", None), + ], +) +def test_discuss_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) + options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -466,11 +613,14 @@ def test_discuss_service_client_client_options_credentials_file(client_class, tr api_audience=None, ) + def test_discuss_service_client_client_options_from_dict(): - with mock.patch('google.ai.generativelanguage_v1beta3.services.discuss_service.transports.DiscussServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.ai.generativelanguage_v1beta3.services.discuss_service.transports.DiscussServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = DiscussServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -485,17 +635,30 @@ def test_discuss_service_client_client_options_from_dict(): ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", grpc_helpers), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), -]) -def test_discuss_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + DiscussServiceClient, + transports.DiscussServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + DiscussServiceAsyncClient, + transports.DiscussServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def 
test_discuss_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) + options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -528,8 +691,7 @@ def test_discuss_service_client_create_channel_credentials_file(client_class, tr credentials=file_creds, credentials_file=None, quota_project_id=None, - default_scopes=( -), + default_scopes=(), scopes=None, default_host="generativelanguage.googleapis.com", ssl_credentials=None, @@ -540,11 +702,14 @@ def test_discuss_service_client_create_channel_credentials_file(client_class, tr ) -@pytest.mark.parametrize("request_type", [ - discuss_service.GenerateMessageRequest, - dict, -]) -def test_generate_message(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + discuss_service.GenerateMessageRequest, + dict, + ], +) +def test_generate_message(request_type, transport: str = "grpc"): client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -555,12 +720,9 @@ def test_generate_message(request_type, transport: str = 'grpc'): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = discuss_service.GenerateMessageResponse( - ) + call.return_value = discuss_service.GenerateMessageResponse() response = client.generate_message(request) # Establish that the underlying gRPC stub method was called. @@ -577,20 +739,21 @@ def test_generate_message_empty_call(): # i.e. request == None and no flattened fields passed, work. client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: client.generate_message() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == discuss_service.GenerateMessageRequest() + @pytest.mark.asyncio -async def test_generate_message_async(transport: str = 'grpc_asyncio', request_type=discuss_service.GenerateMessageRequest): +async def test_generate_message_async( + transport: str = "grpc_asyncio", request_type=discuss_service.GenerateMessageRequest +): client = DiscussServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -601,12 +764,11 @@ async def test_generate_message_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.GenerateMessageResponse( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.GenerateMessageResponse() + ) response = await client.generate_message(request) # Establish that the underlying gRPC stub method was called. @@ -632,12 +794,10 @@ def test_generate_message_field_headers(): # a field header. Set these to a non-empty value. request = discuss_service.GenerateMessageRequest() - request.model = 'model_value' + request.model = "model_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: call.return_value = discuss_service.GenerateMessageResponse() client.generate_message(request) @@ -649,9 +809,9 @@ def test_generate_message_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -664,13 +824,13 @@ async def test_generate_message_field_headers_async(): # a field header. Set these to a non-empty value. request = discuss_service.GenerateMessageRequest() - request.model = 'model_value' + request.model = "model_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.GenerateMessageResponse()) + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.GenerateMessageResponse() + ) await client.generate_message(request) # Establish that the underlying gRPC stub method was called. @@ -681,9 +841,9 @@ async def test_generate_message_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] def test_generate_message_flattened(): @@ -692,16 +852,14 @@ def test_generate_message_flattened(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = discuss_service.GenerateMessageResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.generate_message( - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), temperature=0.1198, candidate_count=1573, top_p=0.546, @@ -713,10 +871,10 @@ def test_generate_message_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].model - mock_val = 'model_value' + mock_val = "model_value" assert arg == mock_val arg = args[0].prompt - mock_val = discuss_service.MessagePrompt(context='context_value') + mock_val = discuss_service.MessagePrompt(context="context_value") assert arg == mock_val assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) arg = args[0].candidate_count @@ -738,14 +896,15 @@ def test_generate_message_flattened_error(): with pytest.raises(ValueError): client.generate_message( discuss_service.GenerateMessageRequest(), - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), temperature=0.1198, candidate_count=1573, top_p=0.546, top_k=541, ) + @pytest.mark.asyncio async def test_generate_message_flattened_async(): client = DiscussServiceAsyncClient( @@ -753,18 +912,18 @@ async def test_generate_message_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_message), - '__call__') as call: + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = discuss_service.GenerateMessageResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.GenerateMessageResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.GenerateMessageResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.generate_message( - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), temperature=0.1198, candidate_count=1573, top_p=0.546, @@ -776,10 +935,10 @@ async def test_generate_message_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].model - mock_val = 'model_value' + mock_val = "model_value" assert arg == mock_val arg = args[0].prompt - mock_val = discuss_service.MessagePrompt(context='context_value') + mock_val = discuss_service.MessagePrompt(context="context_value") assert arg == mock_val assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) arg = args[0].candidate_count @@ -790,6 +949,7 @@ async def test_generate_message_flattened_async(): mock_val = 541 assert arg == mock_val + @pytest.mark.asyncio async def test_generate_message_flattened_error_async(): client = DiscussServiceAsyncClient( @@ -801,8 +961,8 @@ async def test_generate_message_flattened_error_async(): with pytest.raises(ValueError): await client.generate_message( discuss_service.GenerateMessageRequest(), - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), temperature=0.1198, candidate_count=1573, top_p=0.546, @@ -810,11 +970,14 @@ async def test_generate_message_flattened_error_async(): ) -@pytest.mark.parametrize("request_type", [ - discuss_service.CountMessageTokensRequest, - dict, -]) -def test_count_message_tokens(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + discuss_service.CountMessageTokensRequest, + dict, + ], +) +def test_count_message_tokens(request_type, transport: str = "grpc"): client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -826,8 +989,8 @@ def test_count_message_tokens(request_type, transport: str = 'grpc'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: + type(client.transport.count_message_tokens), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = discuss_service.CountMessageTokensResponse( token_count=1193, @@ -849,20 +1012,24 @@ def test_count_message_tokens_empty_call(): # i.e. request == None and no flattened fields passed, work. client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. 
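+    # Invoking the method with no arguments should still send a default
+    # CountMessageTokensRequest, which the assertion below verifies.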
with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: + type(client.transport.count_message_tokens), "__call__" + ) as call: client.count_message_tokens() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == discuss_service.CountMessageTokensRequest() + @pytest.mark.asyncio -async def test_count_message_tokens_async(transport: str = 'grpc_asyncio', request_type=discuss_service.CountMessageTokensRequest): +async def test_count_message_tokens_async( + transport: str = "grpc_asyncio", + request_type=discuss_service.CountMessageTokensRequest, +): client = DiscussServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -874,12 +1041,14 @@ async def test_count_message_tokens_async(transport: str = 'grpc_asyncio', reque # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: + type(client.transport.count_message_tokens), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.CountMessageTokensResponse( - token_count=1193, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.CountMessageTokensResponse( + token_count=1193, + ) + ) response = await client.count_message_tokens(request) # Establish that the underlying gRPC stub method was called. @@ -906,12 +1075,12 @@ def test_count_message_tokens_field_headers(): # a field header. Set these to a non-empty value. request = discuss_service.CountMessageTokensRequest() - request.model = 'model_value' + request.model = "model_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: + type(client.transport.count_message_tokens), "__call__" + ) as call: call.return_value = discuss_service.CountMessageTokensResponse() client.count_message_tokens(request) @@ -923,9 +1092,9 @@ def test_count_message_tokens_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -938,13 +1107,15 @@ async def test_count_message_tokens_field_headers_async(): # a field header. Set these to a non-empty value. request = discuss_service.CountMessageTokensRequest() - request.model = 'model_value' + request.model = "model_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.CountMessageTokensResponse()) + type(client.transport.count_message_tokens), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.CountMessageTokensResponse() + ) await client.count_message_tokens(request) # Establish that the underlying gRPC stub method was called. @@ -955,9 +1126,9 @@ async def test_count_message_tokens_field_headers_async(): # Establish that the field header was sent. 
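+    # Routing information travels as gRPC metadata: the client is expected to
+    # attach an ("x-goog-request-params", "model=...") tuple to the call.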
_, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] def test_count_message_tokens_flattened(): @@ -967,15 +1138,15 @@ def test_count_message_tokens_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: + type(client.transport.count_message_tokens), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = discuss_service.CountMessageTokensResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.count_message_tokens( - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), ) # Establish that the underlying call was made with the expected @@ -983,10 +1154,10 @@ def test_count_message_tokens_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].model - mock_val = 'model_value' + mock_val = "model_value" assert arg == mock_val arg = args[0].prompt - mock_val = discuss_service.MessagePrompt(context='context_value') + mock_val = discuss_service.MessagePrompt(context="context_value") assert arg == mock_val @@ -1000,10 +1171,11 @@ def test_count_message_tokens_flattened_error(): with pytest.raises(ValueError): client.count_message_tokens( discuss_service.CountMessageTokensRequest(), - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), ) + @pytest.mark.asyncio async def test_count_message_tokens_flattened_async(): client = DiscussServiceAsyncClient( @@ -1012,17 +1184,19 @@ async def test_count_message_tokens_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.count_message_tokens), - '__call__') as call: + type(client.transport.count_message_tokens), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = discuss_service.CountMessageTokensResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(discuss_service.CountMessageTokensResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.CountMessageTokensResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.count_message_tokens( - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), ) # Establish that the underlying call was made with the expected @@ -1030,12 +1204,13 @@ async def test_count_message_tokens_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].model - mock_val = 'model_value' + mock_val = "model_value" assert arg == mock_val arg = args[0].prompt - mock_val = discuss_service.MessagePrompt(context='context_value') + mock_val = discuss_service.MessagePrompt(context="context_value") assert arg == mock_val + @pytest.mark.asyncio async def test_count_message_tokens_flattened_error_async(): client = DiscussServiceAsyncClient( @@ -1047,15 +1222,18 @@ async def test_count_message_tokens_flattened_error_async(): with pytest.raises(ValueError): await client.count_message_tokens( discuss_service.CountMessageTokensRequest(), - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), ) -@pytest.mark.parametrize("request_type", [ - discuss_service.GenerateMessageRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + discuss_service.GenerateMessageRequest, + dict, + ], +) def test_generate_message_rest(request_type): client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -1063,14 +1241,13 @@ def test_generate_message_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} + request_init = {"model": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
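+        # For the REST transport the mock intercepts requests.Session.request;
+        # the canned proto below is serialized to JSON to act as the HTTP
+        # response body.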
- return_value = discuss_service.GenerateMessageResponse( - ) + return_value = discuss_service.GenerateMessageResponse() # Wrap the value into a proper Response obj response_value = Response() @@ -1078,7 +1255,7 @@ def test_generate_message_rest(request_type): pb_return_value = discuss_service.GenerateMessageResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.generate_message(request) @@ -1086,58 +1263,66 @@ def test_generate_message_rest(request_type): assert isinstance(response, discuss_service.GenerateMessageResponse) -def test_generate_message_rest_required_fields(request_type=discuss_service.GenerateMessageRequest): +def test_generate_message_rest_required_fields( + request_type=discuss_service.GenerateMessageRequest, +): transport_class = transports.DiscussServiceRestTransport request_init = {} request_init["model"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_message._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).generate_message._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["model"] = 'model_value' + jsonified_request["model"] = "model_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_message._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).generate_message._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "model" in jsonified_request - assert jsonified_request["model"] == 'model_value' + assert jsonified_request["model"] == "model_value" client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = discuss_service.GenerateMessageResponse() # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. 
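+            # The transcode() stub below stands in for the real HTTP rule, so
+            # the test does not depend on a concrete URI template.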
pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "post", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, } - transcode_result['body'] = pb_request + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() @@ -1146,39 +1331,56 @@ def test_generate_message_rest_required_fields(request_type=discuss_service.Gene pb_return_value = discuss_service.GenerateMessageResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.generate_message(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_generate_message_rest_unset_required_fields(): - transport = transports.DiscussServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.DiscussServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.generate_message._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("model", "prompt", ))) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "prompt", + ) + ) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_generate_message_rest_interceptors(null_interceptor): transport = transports.DiscussServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.DiscussServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.DiscussServiceRestInterceptor(), + ) client = DiscussServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.DiscussServiceRestInterceptor, "post_generate_message") as post, \ - mock.patch.object(transports.DiscussServiceRestInterceptor, "pre_generate_message") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DiscussServiceRestInterceptor, "post_generate_message" + ) as post, mock.patch.object( + transports.DiscussServiceRestInterceptor, "pre_generate_message" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = discuss_service.GenerateMessageRequest.pb(discuss_service.GenerateMessageRequest()) + pb_message = discuss_service.GenerateMessageRequest.pb( + discuss_service.GenerateMessageRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -1189,34 +1391,46 @@ def test_generate_message_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = discuss_service.GenerateMessageResponse.to_json(discuss_service.GenerateMessageResponse()) + req.return_value._content = discuss_service.GenerateMessageResponse.to_json( + discuss_service.GenerateMessageResponse() + ) request = discuss_service.GenerateMessageRequest() - metadata =[ + metadata = 
[ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = discuss_service.GenerateMessageResponse() - client.generate_message(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.generate_message( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_generate_message_rest_bad_request(transport: str = 'rest', request_type=discuss_service.GenerateMessageRequest): +def test_generate_message_rest_bad_request( + transport: str = "rest", request_type=discuss_service.GenerateMessageRequest +): client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} + request_init = {"model": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -1232,17 +1446,17 @@ def test_generate_message_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = discuss_service.GenerateMessageResponse() # get arguments that satisfy an http rule for this method - sample_request = {'model': 'models/sample1'} + sample_request = {"model": "models/sample1"} # get truthy value for each flattened field mock_args = dict( - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), temperature=0.1198, candidate_count=1573, top_p=0.546, @@ -1255,7 +1469,7 @@ def test_generate_message_rest_flattened(): response_value.status_code = 200 pb_return_value = discuss_service.GenerateMessageResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.generate_message(**mock_args) @@ -1264,10 +1478,13 @@ def test_generate_message_rest_flattened(): # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{model=models/*}:generateMessage" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{model=models/*}:generateMessage" % client.transport._host, + args[1], + ) -def test_generate_message_rest_flattened_error(transport: str = 'rest'): +def test_generate_message_rest_flattened_error(transport: str = "rest"): client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1278,8 +1495,8 @@ def test_generate_message_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.generate_message( discuss_service.GenerateMessageRequest(), - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), temperature=0.1198, candidate_count=1573, top_p=0.546, @@ -1289,15 +1506,17 @@ def test_generate_message_rest_flattened_error(transport: str = 'rest'): def test_generate_message_rest_error(): client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) -@pytest.mark.parametrize("request_type", [ - discuss_service.CountMessageTokensRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + discuss_service.CountMessageTokensRequest, + dict, + ], +) def test_count_message_tokens_rest(request_type): client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -1305,14 +1524,14 @@ def test_count_message_tokens_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} + request_init = {"model": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
return_value = discuss_service.CountMessageTokensResponse( - token_count=1193, + token_count=1193, ) # Wrap the value into a proper Response obj @@ -1321,7 +1540,7 @@ def test_count_message_tokens_rest(request_type): pb_return_value = discuss_service.CountMessageTokensResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.count_message_tokens(request) @@ -1330,99 +1549,126 @@ def test_count_message_tokens_rest(request_type): assert response.token_count == 1193 -def test_count_message_tokens_rest_required_fields(request_type=discuss_service.CountMessageTokensRequest): +def test_count_message_tokens_rest_required_fields( + request_type=discuss_service.CountMessageTokensRequest, +): transport_class = transports.DiscussServiceRestTransport request_init = {} request_init["model"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).count_message_tokens._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).count_message_tokens._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["model"] = 'model_value' + jsonified_request["model"] = "model_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).count_message_tokens._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).count_message_tokens._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "model" in jsonified_request - assert jsonified_request["model"] == 'model_value' + assert jsonified_request["model"] == "model_value" client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = discuss_service.CountMessageTokensResponse() # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. 
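+            # As in the generate_message variant above, transcode() is stubbed
+            # so no concrete URI template is required.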
pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "post", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, } - transcode_result['body'] = pb_request + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 - pb_return_value = discuss_service.CountMessageTokensResponse.pb(return_value) + pb_return_value = discuss_service.CountMessageTokensResponse.pb( + return_value + ) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.count_message_tokens(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_count_message_tokens_rest_unset_required_fields(): - transport = transports.DiscussServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.DiscussServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.count_message_tokens._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("model", "prompt", ))) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "prompt", + ) + ) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_count_message_tokens_rest_interceptors(null_interceptor): transport = transports.DiscussServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.DiscussServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.DiscussServiceRestInterceptor(), + ) client = DiscussServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.DiscussServiceRestInterceptor, "post_count_message_tokens") as post, \ - mock.patch.object(transports.DiscussServiceRestInterceptor, "pre_count_message_tokens") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DiscussServiceRestInterceptor, "post_count_message_tokens" + ) as post, mock.patch.object( + transports.DiscussServiceRestInterceptor, "pre_count_message_tokens" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = discuss_service.CountMessageTokensRequest.pb(discuss_service.CountMessageTokensRequest()) + pb_message = discuss_service.CountMessageTokensRequest.pb( + discuss_service.CountMessageTokensRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -1433,34 +1679,46 @@ def test_count_message_tokens_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = discuss_service.CountMessageTokensResponse.to_json(discuss_service.CountMessageTokensResponse()) + req.return_value._content = discuss_service.CountMessageTokensResponse.to_json( + discuss_service.CountMessageTokensResponse() + 
) request = discuss_service.CountMessageTokensRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = discuss_service.CountMessageTokensResponse() - client.count_message_tokens(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.count_message_tokens( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_count_message_tokens_rest_bad_request(transport: str = 'rest', request_type=discuss_service.CountMessageTokensRequest): +def test_count_message_tokens_rest_bad_request( + transport: str = "rest", request_type=discuss_service.CountMessageTokensRequest +): client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} + request_init = {"model": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -1476,17 +1734,17 @@ def test_count_message_tokens_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = discuss_service.CountMessageTokensResponse() # get arguments that satisfy an http rule for this method - sample_request = {'model': 'models/sample1'} + sample_request = {"model": "models/sample1"} # get truthy value for each flattened field mock_args = dict( - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), ) mock_args.update(sample_request) @@ -1495,7 +1753,7 @@ def test_count_message_tokens_rest_flattened(): response_value.status_code = 200 pb_return_value = discuss_service.CountMessageTokensResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.count_message_tokens(**mock_args) @@ -1504,10 +1762,13 @@ def test_count_message_tokens_rest_flattened(): # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{model=models/*}:countMessageTokens" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{model=models/*}:countMessageTokens" % client.transport._host, + args[1], + ) -def test_count_message_tokens_rest_flattened_error(transport: str = 'rest'): +def test_count_message_tokens_rest_flattened_error(transport: str = "rest"): client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1518,15 +1779,14 @@ def test_count_message_tokens_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.count_message_tokens( discuss_service.CountMessageTokensRequest(), - model='model_value', - prompt=discuss_service.MessagePrompt(context='context_value'), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), ) def test_count_message_tokens_rest_error(): client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -1568,8 +1828,7 @@ def test_credentials_transport_error(): options.api_key = "api_key" with pytest.raises(ValueError): client = DiscussServiceClient( - client_options=options, - credentials=ga_credentials.AnonymousCredentials() + client_options=options, credentials=ga_credentials.AnonymousCredentials() ) # It is an error to provide scopes and a transport instance. @@ -1591,6 +1850,7 @@ def test_transport_instance(): client = DiscussServiceClient(transport=transport) assert client.transport is transport + def test_transport_get_channel(): # A client may be instantiated with a custom transport instance. transport = transports.DiscussServiceGrpcTransport( @@ -1605,28 +1865,37 @@ def test_transport_get_channel(): channel = transport.grpc_channel assert channel -@pytest.mark.parametrize("transport_class", [ - transports.DiscussServiceGrpcTransport, - transports.DiscussServiceGrpcAsyncIOTransport, - transports.DiscussServiceRestTransport, -]) + +@pytest.mark.parametrize( + "transport_class", + [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + transports.DiscussServiceRestTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(google.auth, 'default') as adc: + with mock.patch.object(google.auth, "default") as adc: adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() -@pytest.mark.parametrize("transport_name", [ - "grpc", - "rest", -]) + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "rest", + ], +) def test_transport_kind(transport_name): transport = DiscussServiceClient.get_transport_class(transport_name)( credentials=ga_credentials.AnonymousCredentials(), ) assert transport.kind == transport_name + def test_transport_grpc_default(): # A client should use the gRPC transport by default. 
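+    # With no transport argument supplied, the client should end up holding a
+    # DiscussServiceGrpcTransport instance.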
client = DiscussServiceClient( @@ -1637,18 +1906,21 @@ def test_transport_grpc_default(): transports.DiscussServiceGrpcTransport, ) + def test_discuss_service_base_transport_error(): # Passing both a credentials object and credentials_file should raise an error with pytest.raises(core_exceptions.DuplicateCredentialArgs): transport = transports.DiscussServiceTransport( credentials=ga_credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_discuss_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.ai.generativelanguage_v1beta3.services.discuss_service.transports.DiscussServiceTransport.__init__') as Transport: + with mock.patch( + "google.ai.generativelanguage_v1beta3.services.discuss_service.transports.DiscussServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.DiscussServiceTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -1657,8 +1929,8 @@ def test_discuss_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'generate_message', - 'count_message_tokens', + "generate_message", + "count_message_tokens", ) for method in methods: with pytest.raises(NotImplementedError): @@ -1669,7 +1941,7 @@ def test_discuss_service_base_transport(): # Catch all for all remaining methods and properties remainder = [ - 'kind', + "kind", ] for r in remainder: with pytest.raises(NotImplementedError): @@ -1678,24 +1950,30 @@ def test_discuss_service_base_transport(): def test_discuss_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta3.services.discuss_service.transports.DiscussServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.ai.generativelanguage_v1beta3.services.discuss_service.transports.DiscussServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) transport = transports.DiscussServiceTransport( credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", + load_creds.assert_called_once_with( + "credentials.json", scopes=None, - default_scopes=( -), + default_scopes=(), quota_project_id="octopus", ) def test_discuss_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta3.services.discuss_service.transports.DiscussServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.ai.generativelanguage_v1beta3.services.discuss_service.transports.DiscussServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport = transports.DiscussServiceTransport() @@ -1704,13 +1982,12 @@ def test_discuss_service_base_transport_with_adc(): def test_discuss_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. 
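+    # google.auth.default() is patched below so the test never reaches a real
+    # Application Default Credentials environment.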
- with mock.patch.object(google.auth, 'default', autospec=True) as adc: + with mock.patch.object(google.auth, "default", autospec=True) as adc: adc.return_value = (ga_credentials.AnonymousCredentials(), None) DiscussServiceClient() adc.assert_called_once_with( scopes=None, - default_scopes=( -), + default_scopes=(), quota_project_id=None, ) @@ -1725,7 +2002,7 @@ def test_discuss_service_auth_adc(): def test_discuss_service_transport_auth_adc(transport_class): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(google.auth, 'default', autospec=True) as adc: + with mock.patch.object(google.auth, "default", autospec=True) as adc: adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport_class(quota_project_id="octopus", scopes=["1", "2"]) adc.assert_called_once_with( @@ -1744,47 +2021,45 @@ def test_discuss_service_transport_auth_adc(transport_class): ], ) def test_discuss_service_transport_auth_gdch_credentials(transport_class): - host = 'https://language.com' - api_audience_tests = [None, 'https://language2.com'] - api_audience_expect = [host, 'https://language2.com'] + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] for t, e in zip(api_audience_tests, api_audience_expect): - with mock.patch.object(google.auth, 'default', autospec=True) as adc: + with mock.patch.object(google.auth, "default", autospec=True) as adc: gdch_mock = mock.MagicMock() - type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock) + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) adc.return_value = (gdch_mock, None) transport_class(host=host, api_audience=t) - gdch_mock.with_gdch_audience.assert_called_once_with( - e - ) + gdch_mock.with_gdch_audience.assert_called_once_with(e) @pytest.mark.parametrize( "transport_class,grpc_helpers", [ (transports.DiscussServiceGrpcTransport, grpc_helpers), - (transports.DiscussServiceGrpcAsyncIOTransport, grpc_helpers_async) + (transports.DiscussServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) def test_discuss_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. 
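+    # Both google.auth.default() and grpc_helpers.create_channel are patched,
+    # letting the test assert on the exact options the transport passes when
+    # opening its channel.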
- with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object( + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( grpc_helpers, "create_channel", autospec=True ) as create_channel: creds = ga_credentials.AnonymousCredentials() adc.return_value = (creds, None) - transport_class( - quota_project_id="octopus", - scopes=["1", "2"] - ) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) create_channel.assert_called_with( "generativelanguage.googleapis.com:443", credentials=creds, credentials_file=None, quota_project_id="octopus", - default_scopes=( -), + default_scopes=(), scopes=["1", "2"], default_host="generativelanguage.googleapis.com", ssl_credentials=None, @@ -1795,10 +2070,14 @@ def test_discuss_service_transport_create_channel(transport_class, grpc_helpers) ) -@pytest.mark.parametrize("transport_class", [transports.DiscussServiceGrpcTransport, transports.DiscussServiceGrpcAsyncIOTransport]) -def test_discuss_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + ], +) +def test_discuss_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = ga_credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. @@ -1807,7 +2086,7 @@ def test_discuss_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", @@ -1828,61 +2107,77 @@ def test_discuss_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) + def test_discuss_service_http_transport_client_cert_source_for_mtls(): cred = ga_credentials.AnonymousCredentials() - with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel: - transports.DiscussServiceRestTransport ( - credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.DiscussServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback ) mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) -@pytest.mark.parametrize("transport_name", [ - "grpc", - "grpc_asyncio", - "rest", -]) +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) def test_discuss_service_host_no_port(transport_name): client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'), - transport=transport_name, + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com" + ), + transport=transport_name, ) 
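+    # gRPC transports keep an explicit :443 port on the host, while the REST
+    # transport stores a full https:// URL; the assertion below accepts either
+    # form depending on transport_name.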
assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else 'https://generativelanguage.googleapis.com' + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" ) -@pytest.mark.parametrize("transport_name", [ - "grpc", - "grpc_asyncio", - "rest", -]) + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) def test_discuss_service_host_with_port(transport_name): client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com:8000" + ), transport=transport_name, ) assert client.transport._host == ( - 'generativelanguage.googleapis.com:8000' - if transport_name in ['grpc', 'grpc_asyncio'] - else 'https://generativelanguage.googleapis.com:8000' + "generativelanguage.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com:8000" ) -@pytest.mark.parametrize("transport_name", [ - "rest", -]) + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) def test_discuss_service_client_transport_session_collision(transport_name): creds1 = ga_credentials.AnonymousCredentials() creds2 = ga_credentials.AnonymousCredentials() @@ -1900,8 +2195,10 @@ def test_discuss_service_client_transport_session_collision(transport_name): session1 = client1.transport.count_message_tokens._session session2 = client2.transport.count_message_tokens._session assert session1 != session2 + + def test_discuss_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DiscussServiceGrpcTransport( @@ -1914,7 +2211,7 @@ def test_discuss_service_grpc_transport_channel(): def test_discuss_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DiscussServiceGrpcAsyncIOTransport( @@ -1928,12 +2225,22 @@ def test_discuss_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. 
-@pytest.mark.parametrize("transport_class", [transports.DiscussServiceGrpcTransport, transports.DiscussServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + ], +) def test_discuss_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1942,7 +2249,7 @@ def test_discuss_service_transport_channel_mtls_with_client_cert_source( cred = ga_credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(google.auth, 'default') as adc: + with mock.patch.object(google.auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -1972,17 +2279,23 @@ def test_discuss_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.DiscussServiceGrpcTransport, transports.DiscussServiceGrpcAsyncIOTransport]) -def test_discuss_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + ], +) +def test_discuss_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2012,7 +2325,9 @@ def test_discuss_service_transport_channel_mtls_with_adc( def test_model_path(): model = "squid" - expected = "models/{model}".format(model=model, ) + expected = "models/{model}".format( + model=model, + ) actual = DiscussServiceClient.model_path(model) assert expected == actual @@ -2027,9 +2342,12 @@ def test_parse_model_path(): actual = DiscussServiceClient.parse_model_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "whelk" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = DiscussServiceClient.common_billing_account_path(billing_account) assert expected == actual @@ -2044,9 +2362,12 @@ def test_parse_common_billing_account_path(): actual = DiscussServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "oyster" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format( + folder=folder, + ) actual = 
DiscussServiceClient.common_folder_path(folder) assert expected == actual @@ -2061,9 +2382,12 @@ def test_parse_common_folder_path(): actual = DiscussServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "cuttlefish" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format( + organization=organization, + ) actual = DiscussServiceClient.common_organization_path(organization) assert expected == actual @@ -2078,9 +2402,12 @@ def test_parse_common_organization_path(): actual = DiscussServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "winkle" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format( + project=project, + ) actual = DiscussServiceClient.common_project_path(project) assert expected == actual @@ -2095,10 +2422,14 @@ def test_parse_common_project_path(): actual = DiscussServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "scallop" location = "abalone" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) actual = DiscussServiceClient.common_location_path(project, location) assert expected == actual @@ -2118,14 +2449,18 @@ def test_parse_common_location_path(): def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.DiscussServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.DiscussServiceTransport, "_prep_wrapped_messages" + ) as prep: client = DiscussServiceClient( credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.DiscussServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.DiscussServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = DiscussServiceClient.get_transport_class() transport = transport_class( credentials=ga_credentials.AnonymousCredentials(), @@ -2133,13 +2468,16 @@ def test_client_with_default_client_info(): ) prep.assert_called_once_with(client_info) + @pytest.mark.asyncio async def test_transport_close_async(): client = DiscussServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", ) - with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close: + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: async with client: close.assert_not_called() close.assert_called_once() @@ -2153,23 +2491,24 @@ def test_transport_close(): for transport, close_name in transports.items(): client = DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport + credentials=ga_credentials.AnonymousCredentials(), transport=transport ) - with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close: + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: with client: close.assert_not_called() close.assert_called_once() + def test_client_ctx(): transports = [ - 'rest', - 'grpc', + "rest", + "grpc", ] for transport in transports: client = 
DiscussServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport + credentials=ga_credentials.AnonymousCredentials(), transport=transport ) # Test client calls underlying transport. with mock.patch.object(type(client.transport), "close") as close: @@ -2178,10 +2517,14 @@ def test_client_ctx(): pass close.assert_called() -@pytest.mark.parametrize("client_class,transport_class", [ - (DiscussServiceClient, transports.DiscussServiceGrpcTransport), - (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport), -]) + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport), + ], +) def test_api_key_credentials(client_class, transport_class): with mock.patch.object( google.auth._default, "get_api_key_credentials", create=True diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_model_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/test_model_service.py similarity index 68% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_model_service.py rename to packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/test_model_service.py index b0c5932de677..4a8fbc67129d 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_model_service.py +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/test_model_service.py @@ -14,6 +14,7 @@ # limitations under the License. # import os + # try/except added for compatibility with python < 3.8 try: from unittest import mock @@ -21,45 +22,47 @@ except ImportError: # pragma: NO COVER import mock -import grpc -from grpc.experimental import aio from collections.abc import Iterable -from google.protobuf import json_format import json import math -import pytest -from proto.marshal.rules.dates import DurationRule, TimestampRule -from proto.marshal.rules import wrappers -from requests import Response -from requests import Request, PreparedRequest -from requests.sessions import Session -from google.protobuf import json_format -from google.ai.generativelanguage_v1beta3.services.model_service import ModelServiceAsyncClient -from google.ai.generativelanguage_v1beta3.services.model_service import ModelServiceClient -from google.ai.generativelanguage_v1beta3.services.model_service import pagers -from google.ai.generativelanguage_v1beta3.services.model_service import transports -from google.ai.generativelanguage_v1beta3.types import model -from google.ai.generativelanguage_v1beta3.types import model_service -from google.ai.generativelanguage_v1beta3.types import tuned_model -from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model +from google.api_core import ( + future, + gapic_v1, + grpc_helpers, + grpc_helpers_async, + operation, + operations_v1, + path_template, +) from google.api_core import client_options from google.api_core import exceptions as core_exceptions -from google.api_core import future -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import operation from google.api_core import operation_async # type: ignore -from google.api_core import operations_v1 -from 
google.api_core import path_template +import google.auth from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError -from google.longrunning import operations_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore from google.oauth2 import service_account from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import json_format from google.protobuf import timestamp_pb2 # type: ignore -import google.auth +import grpc +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +import pytest +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + +from google.ai.generativelanguage_v1beta3.services.model_service import ( + ModelServiceAsyncClient, + ModelServiceClient, + pagers, + transports, +) +from google.ai.generativelanguage_v1beta3.types import tuned_model as gag_tuned_model +from google.ai.generativelanguage_v1beta3.types import model, model_service +from google.ai.generativelanguage_v1beta3.types import tuned_model def client_cert_source_callback(): @@ -70,7 +73,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -81,21 +88,37 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert ModelServiceClient._get_default_mtls_endpoint(None) is None - assert ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize("client_class,transport_name", [ - (ModelServiceClient, "grpc"), - (ModelServiceAsyncClient, "grpc_asyncio"), - (ModelServiceClient, "rest"), -]) +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (ModelServiceClient, "grpc"), + (ModelServiceAsyncClient, "grpc_asyncio"), + (ModelServiceClient, "rest"), + ], +) def test_model_service_client_from_service_account_info(client_class, transport_name): creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info, transport=transport_name) @@ -103,52 
+126,68 @@ def test_model_service_client_from_service_account_info(client_class, transport_ assert isinstance(client, client_class) assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" ) -@pytest.mark.parametrize("transport_class,transport_name", [ - (transports.ModelServiceGrpcTransport, "grpc"), - (transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (transports.ModelServiceRestTransport, "rest"), -]) -def test_model_service_client_service_account_always_use_jwt(transport_class, transport_name): - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.ModelServiceGrpcTransport, "grpc"), + (transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.ModelServiceRestTransport, "rest"), + ], +) +def test_model_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: creds = service_account.Credentials(None, None, None) transport = transport_class(credentials=creds, always_use_jwt_access=True) use_jwt.assert_called_once_with(True) - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: creds = service_account.Credentials(None, None, None) transport = transport_class(credentials=creds, always_use_jwt_access=False) use_jwt.assert_not_called() -@pytest.mark.parametrize("client_class,transport_name", [ - (ModelServiceClient, "grpc"), - (ModelServiceAsyncClient, "grpc_asyncio"), - (ModelServiceClient, "rest"), -]) +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (ModelServiceClient, "grpc"), + (ModelServiceAsyncClient, "grpc_asyncio"), + (ModelServiceClient, "rest"), + ], +) def test_model_service_client_from_service_account_file(client_class, transport_name): creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name) + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) assert client.transport._credentials == creds assert isinstance(client, client_class) - client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name) + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) assert client.transport._credentials == creds assert isinstance(client, client_class) assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" ) @@ -164,30 +203,43 @@ def 
test_model_service_client_get_transport_class(): assert transport == transports.ModelServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (ModelServiceClient, transports.ModelServiceRestTransport, "rest"), -]) -@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) -@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) -def test_model_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) +def test_model_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=ga_credentials.AnonymousCredentials() - ) + with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(ModelServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(transport=transport_name, client_options=options) patched.assert_called_once_with( @@ -205,7 +257,7 @@ def test_model_service_client_client_options(client_class, transport_class, tran # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(transport=transport_name) patched.assert_called_once_with( @@ -223,7 +275,7 @@ def test_model_service_client_client_options(client_class, transport_class, tran # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". 
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(transport=transport_name) patched.assert_called_once_with( @@ -245,13 +297,15 @@ def test_model_service_client_client_options(client_class, transport_class, tran client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -266,8 +320,10 @@ def test_model_service_client_client_options(client_class, transport_class, tran api_audience=None, ) # Check the case api_endpoint is provided - options = client_options.ClientOptions(api_audience="https://language.googleapis.com") - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -279,29 +335,55 @@ def test_model_service_client_client_options(client_class, transport_class, tran quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, always_use_jwt_access=True, - api_audience="https://language.googleapis.com" - ) - -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), - (ModelServiceClient, transports.ModelServiceRestTransport, "rest", "true"), - (ModelServiceClient, transports.ModelServiceRestTransport, "rest", "false"), -]) -@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) -@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest", "true"), + (ModelServiceClient, 
transports.ModelServiceRestTransport, "rest", "false"), + ], +) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_model_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_model_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) @@ -326,10 +408,18 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -352,9 +442,14 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class(transport=transport_name) patched.assert_called_once_with( @@ -370,19 +465,27 @@ def test_model_service_client_mtls_env_auto(client_class, transport_class, trans ) -@pytest.mark.parametrize("client_class", [ - ModelServiceClient, ModelServiceAsyncClient -]) -@mock.patch.object(ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient)) -@mock.patch.object(ModelServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceAsyncClient)) +@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient]) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) def test_model_service_client_get_mtls_endpoint_and_cert_source(client_class): mock_client_cert_source = mock.Mock() # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) assert api_endpoint == mock_api_endpoint assert cert_source == mock_client_cert_source @@ -390,8 +493,12 @@ def test_model_service_client_get_mtls_endpoint_and_cert_source(client_class): with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): mock_client_cert_source = mock.Mock() mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) assert api_endpoint == mock_api_endpoint assert cert_source is None @@ -409,31 +516,52 @@ def test_model_service_client_get_mtls_endpoint_and_cert_source(client_class): # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() assert api_endpoint == client_class.DEFAULT_ENDPOINT assert cert_source is None # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT assert cert_source == mock_client_cert_source -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (ModelServiceClient, transports.ModelServiceRestTransport, "rest"), -]) -def test_model_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest"), + ], +) +def test_model_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. 
options = client_options.ClientOptions( scopes=["1", "2"], ) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -448,18 +576,32 @@ def test_model_service_client_client_options_scopes(client_class, transport_clas api_audience=None, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", grpc_helpers), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), - (ModelServiceClient, transports.ModelServiceRestTransport, "rest", None), -]) -def test_model_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + ModelServiceClient, + transports.ModelServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest", None), + ], +) +def test_model_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) + options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -474,12 +616,13 @@ def test_model_service_client_client_options_credentials_file(client_class, tran api_audience=None, ) + def test_model_service_client_client_options_from_dict(): - with mock.patch('google.ai.generativelanguage_v1beta3.services.model_service.transports.ModelServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.ai.generativelanguage_v1beta3.services.model_service.transports.ModelServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None - client = ModelServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} - ) + client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -493,17 +636,30 @@ def test_model_service_client_client_options_from_dict(): ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", grpc_helpers), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), -]) -def test_model_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + ModelServiceClient, + transports.ModelServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def 
test_model_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) + options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -536,8 +692,7 @@ def test_model_service_client_create_channel_credentials_file(client_class, tran credentials=file_creds, credentials_file=None, quota_project_id=None, - default_scopes=( -), + default_scopes=(), scopes=None, default_host="generativelanguage.googleapis.com", ssl_credentials=None, @@ -548,11 +703,14 @@ def test_model_service_client_create_channel_credentials_file(client_class, tran ) -@pytest.mark.parametrize("request_type", [ - model_service.GetModelRequest, - dict, -]) -def test_get_model(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + model_service.GetModelRequest, + dict, + ], +) +def test_get_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -563,19 +721,17 @@ def test_get_model(request_type, transport: str = 'grpc'): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model.Model( - name='name_value', - base_model_id='base_model_id_value', - version='version_value', - display_name='display_name_value', - description='description_value', + name="name_value", + base_model_id="base_model_id_value", + version="version_value", + display_name="display_name_value", + description="description_value", input_token_limit=1838, output_token_limit=1967, - supported_generation_methods=['supported_generation_methods_value'], + supported_generation_methods=["supported_generation_methods_value"], temperature=0.1198, top_p=0.546, top_k=541, @@ -589,14 +745,16 @@ def test_get_model(request_type, transport: str = 'grpc'): # Establish that the response is the type that we expect. assert isinstance(response, model.Model) - assert response.name == 'name_value' - assert response.base_model_id == 'base_model_id_value' - assert response.version == 'version_value' - assert response.display_name == 'display_name_value' - assert response.description == 'description_value' + assert response.name == "name_value" + assert response.base_model_id == "base_model_id_value" + assert response.version == "version_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" assert response.input_token_limit == 1838 assert response.output_token_limit == 1967 - assert response.supported_generation_methods == ['supported_generation_methods_value'] + assert response.supported_generation_methods == [ + "supported_generation_methods_value" + ] assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) assert response.top_k == 541 @@ -607,20 +765,21 @@ def test_get_model_empty_call(): # i.e. 
request == None and no flattened fields passed, work. client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: client.get_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetModelRequest() + @pytest.mark.asyncio -async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=model_service.GetModelRequest): +async def test_get_model_async( + transport: str = "grpc_asyncio", request_type=model_service.GetModelRequest +): client = ModelServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -631,23 +790,23 @@ async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=mod request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(model.Model( - name='name_value', - base_model_id='base_model_id_value', - version='version_value', - display_name='display_name_value', - description='description_value', - input_token_limit=1838, - output_token_limit=1967, - supported_generation_methods=['supported_generation_methods_value'], - temperature=0.1198, - top_p=0.546, - top_k=541, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model.Model( + name="name_value", + base_model_id="base_model_id_value", + version="version_value", + display_name="display_name_value", + description="description_value", + input_token_limit=1838, + output_token_limit=1967, + supported_generation_methods=["supported_generation_methods_value"], + temperature=0.1198, + top_p=0.546, + top_k=541, + ) + ) response = await client.get_model(request) # Establish that the underlying gRPC stub method was called. @@ -657,14 +816,16 @@ async def test_get_model_async(transport: str = 'grpc_asyncio', request_type=mod # Establish that the response is the type that we expect. assert isinstance(response, model.Model) - assert response.name == 'name_value' - assert response.base_model_id == 'base_model_id_value' - assert response.version == 'version_value' - assert response.display_name == 'display_name_value' - assert response.description == 'description_value' + assert response.name == "name_value" + assert response.base_model_id == "base_model_id_value" + assert response.version == "version_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" assert response.input_token_limit == 1838 assert response.output_token_limit == 1967 - assert response.supported_generation_methods == ['supported_generation_methods_value'] + assert response.supported_generation_methods == [ + "supported_generation_methods_value" + ] assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) assert response.top_k == 541 @@ -684,12 +845,10 @@ def test_get_model_field_headers(): # a field header. Set these to a non-empty value. 
request = model_service.GetModelRequest() - request.name = 'name_value' + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: call.return_value = model.Model() client.get_model(request) @@ -701,9 +860,9 @@ def test_get_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -716,12 +875,10 @@ async def test_get_model_field_headers_async(): # a field header. Set these to a non-empty value. request = model_service.GetModelRequest() - request.name = 'name_value' + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) await client.get_model(request) @@ -733,9 +890,9 @@ async def test_get_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] def test_get_model_flattened(): @@ -744,15 +901,13 @@ def test_get_model_flattened(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model.Model() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.get_model( - name='name_value', + name="name_value", ) # Establish that the underlying call was made with the expected @@ -760,7 +915,7 @@ def test_get_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].name - mock_val = 'name_value' + mock_val = "name_value" assert arg == mock_val @@ -774,9 +929,10 @@ def test_get_model_flattened_error(): with pytest.raises(ValueError): client.get_model( model_service.GetModelRequest(), - name='name_value', + name="name_value", ) + @pytest.mark.asyncio async def test_get_model_flattened_async(): client = ModelServiceAsyncClient( @@ -784,9 +940,7 @@ async def test_get_model_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model.Model() @@ -794,7 +948,7 @@ async def test_get_model_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.get_model( - name='name_value', + name="name_value", ) # Establish that the underlying call was made with the expected @@ -802,9 +956,10 @@ async def test_get_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].name - mock_val = 'name_value' + mock_val = "name_value" assert arg == mock_val + @pytest.mark.asyncio async def test_get_model_flattened_error_async(): client = ModelServiceAsyncClient( @@ -816,15 +971,18 @@ async def test_get_model_flattened_error_async(): with pytest.raises(ValueError): await client.get_model( model_service.GetModelRequest(), - name='name_value', + name="name_value", ) -@pytest.mark.parametrize("request_type", [ - model_service.ListModelsRequest, - dict, -]) -def test_list_models(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + model_service.ListModelsRequest, + dict, + ], +) +def test_list_models(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -835,12 +993,10 @@ def test_list_models(request_type, transport: str = 'grpc'): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse( - next_page_token='next_page_token_value', + next_page_token="next_page_token_value", ) response = client.list_models(request) @@ -851,7 +1007,7 @@ def test_list_models(request_type, transport: str = 'grpc'): # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_models_empty_call(): @@ -859,20 +1015,21 @@ def test_list_models_empty_call(): # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: client.list_models() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListModelsRequest() + @pytest.mark.asyncio -async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=model_service.ListModelsRequest): +async def test_list_models_async( + transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest +): client = ModelServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -883,13 +1040,13 @@ async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=m request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. 
- call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_models(request) # Establish that the underlying gRPC stub method was called. @@ -899,7 +1056,7 @@ async def test_list_models_async(transport: str = 'grpc_asyncio', request_type=m # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListModelsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -913,16 +1070,14 @@ def test_list_models_flattened(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.list_models( page_size=951, - page_token='page_token_value', + page_token="page_token_value", ) # Establish that the underlying call was made with the expected @@ -933,7 +1088,7 @@ def test_list_models_flattened(): mock_val = 951 assert arg == mock_val arg = args[0].page_token - mock_val = 'page_token_value' + mock_val = "page_token_value" assert arg == mock_val @@ -948,9 +1103,10 @@ def test_list_models_flattened_error(): client.list_models( model_service.ListModelsRequest(), page_size=951, - page_token='page_token_value', + page_token="page_token_value", ) + @pytest.mark.asyncio async def test_list_models_flattened_async(): client = ModelServiceAsyncClient( @@ -958,18 +1114,18 @@ async def test_list_models_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListModelsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListModelsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.list_models( page_size=951, - page_token='page_token_value', + page_token="page_token_value", ) # Establish that the underlying call was made with the expected @@ -980,9 +1136,10 @@ async def test_list_models_flattened_async(): mock_val = 951 assert arg == mock_val arg = args[0].page_token - mock_val = 'page_token_value' + mock_val = "page_token_value" assert arg == mock_val + @pytest.mark.asyncio async def test_list_models_flattened_error_async(): client = ModelServiceAsyncClient( @@ -995,7 +1152,7 @@ async def test_list_models_flattened_error_async(): await client.list_models( model_service.ListModelsRequest(), page_size=951, - page_token='page_token_value', + page_token="page_token_value", ) @@ -1006,9 +1163,7 @@ def test_list_models_pager(transport_name: str = "grpc"): ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( @@ -1017,17 +1172,17 @@ def test_list_models_pager(transport_name: str = "grpc"): model.Model(), model.Model(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelsResponse( models=[], - next_page_token='def', + next_page_token="def", ), model_service.ListModelsResponse( models=[ model.Model(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelsResponse( models=[ @@ -1045,8 +1200,9 @@ def test_list_models_pager(transport_name: str = "grpc"): results = list(pager) assert len(results) == 6 - assert all(isinstance(i, model.Model) - for i in results) + assert all(isinstance(i, model.Model) for i in results) + + def test_list_models_pages(transport_name: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials, @@ -1054,9 +1210,7 @@ def test_list_models_pages(transport_name: str = "grpc"): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_models), - '__call__') as call: + with mock.patch.object(type(client.transport.list_models), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( @@ -1065,17 +1219,17 @@ def test_list_models_pages(transport_name: str = "grpc"): model.Model(), model.Model(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelsResponse( models=[], - next_page_token='def', + next_page_token="def", ), model_service.ListModelsResponse( models=[ model.Model(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelsResponse( models=[ @@ -1086,9 +1240,10 @@ def test_list_models_pages(transport_name: str = "grpc"): RuntimeError, ) pages = list(client.list_models(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_models_async_pager(): client = ModelServiceAsyncClient( @@ -1097,8 +1252,8 @@ async def test_list_models_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. 
call.side_effect = ( model_service.ListModelsResponse( @@ -1107,17 +1262,17 @@ async def test_list_models_async_pager(): model.Model(), model.Model(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelsResponse( models=[], - next_page_token='def', + next_page_token="def", ), model_service.ListModelsResponse( models=[ model.Model(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelsResponse( models=[ @@ -1127,15 +1282,16 @@ async def test_list_models_async_pager(): ), RuntimeError, ) - async_pager = await client.list_models(request={},) - assert async_pager.next_page_token == 'abc' + async_pager = await client.list_models( + request={}, + ) + assert async_pager.next_page_token == "abc" responses = [] - async for response in async_pager: # pragma: no branch + async for response in async_pager: # pragma: no branch responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, model.Model) - for i in responses) + assert all(isinstance(i, model.Model) for i in responses) @pytest.mark.asyncio @@ -1146,8 +1302,8 @@ async def test_list_models_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_models), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListModelsResponse( @@ -1156,17 +1312,17 @@ async def test_list_models_async_pages(): model.Model(), model.Model(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelsResponse( models=[], - next_page_token='def', + next_page_token="def", ), model_service.ListModelsResponse( models=[ model.Model(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelsResponse( models=[ @@ -1179,18 +1335,22 @@ async def test_list_models_async_pages(): pages = [] # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 - async for page_ in ( # pragma: no branch + async for page_ in ( # pragma: no branch await client.list_models(request={}) ).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -@pytest.mark.parametrize("request_type", [ - model_service.GetTunedModelRequest, - dict, -]) -def test_get_tuned_model(request_type, transport: str = 'grpc'): + +@pytest.mark.parametrize( + "request_type", + [ + model_service.GetTunedModelRequest, + dict, + ], +) +def test_get_tuned_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1201,19 +1361,17 @@ def test_get_tuned_model(request_type, transport: str = 'grpc'): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_tuned_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: # Designate an appropriate return value for the call. 
call.return_value = tuned_model.TunedModel( - name='name_value', - display_name='display_name_value', - description='description_value', + name="name_value", + display_name="display_name_value", + description="description_value", temperature=0.1198, top_p=0.546, top_k=541, state=tuned_model.TunedModel.State.CREATING, - base_model='base_model_value', + base_model="base_model_value", ) response = client.get_tuned_model(request) @@ -1224,9 +1382,9 @@ def test_get_tuned_model(request_type, transport: str = 'grpc'): # Establish that the response is the type that we expect. assert isinstance(response, tuned_model.TunedModel) - assert response.name == 'name_value' - assert response.display_name == 'display_name_value' - assert response.description == 'description_value' + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) assert response.top_k == 541 @@ -1238,20 +1396,21 @@ def test_get_tuned_model_empty_call(): # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_tuned_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: client.get_tuned_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.GetTunedModelRequest() + @pytest.mark.asyncio -async def test_get_tuned_model_async(transport: str = 'grpc_asyncio', request_type=model_service.GetTunedModelRequest): +async def test_get_tuned_model_async( + transport: str = "grpc_asyncio", request_type=model_service.GetTunedModelRequest +): client = ModelServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1262,19 +1421,19 @@ async def test_get_tuned_model_async(transport: str = 'grpc_asyncio', request_ty request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_tuned_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(tuned_model.TunedModel( - name='name_value', - display_name='display_name_value', - description='description_value', - temperature=0.1198, - top_p=0.546, - top_k=541, - state=tuned_model.TunedModel.State.CREATING, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tuned_model.TunedModel( + name="name_value", + display_name="display_name_value", + description="description_value", + temperature=0.1198, + top_p=0.546, + top_k=541, + state=tuned_model.TunedModel.State.CREATING, + ) + ) response = await client.get_tuned_model(request) # Establish that the underlying gRPC stub method was called. @@ -1284,9 +1443,9 @@ async def test_get_tuned_model_async(transport: str = 'grpc_asyncio', request_ty # Establish that the response is the type that we expect. 
assert isinstance(response, tuned_model.TunedModel) - assert response.name == 'name_value' - assert response.display_name == 'display_name_value' - assert response.description == 'description_value' + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) assert response.top_k == 541 @@ -1307,12 +1466,10 @@ def test_get_tuned_model_field_headers(): # a field header. Set these to a non-empty value. request = model_service.GetTunedModelRequest() - request.name = 'name_value' + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_tuned_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: call.return_value = tuned_model.TunedModel() client.get_tuned_model(request) @@ -1324,9 +1481,9 @@ def test_get_tuned_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1339,13 +1496,13 @@ async def test_get_tuned_model_field_headers_async(): # a field header. Set these to a non-empty value. request = model_service.GetTunedModelRequest() - request.name = 'name_value' + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_tuned_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(tuned_model.TunedModel()) + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tuned_model.TunedModel() + ) await client.get_tuned_model(request) # Establish that the underlying gRPC stub method was called. @@ -1356,9 +1513,9 @@ async def test_get_tuned_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] def test_get_tuned_model_flattened(): @@ -1367,15 +1524,13 @@ def test_get_tuned_model_flattened(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_tuned_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = tuned_model.TunedModel() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.get_tuned_model( - name='name_value', + name="name_value", ) # Establish that the underlying call was made with the expected @@ -1383,7 +1538,7 @@ def test_get_tuned_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].name - mock_val = 'name_value' + mock_val = "name_value" assert arg == mock_val @@ -1397,9 +1552,10 @@ def test_get_tuned_model_flattened_error(): with pytest.raises(ValueError): client.get_tuned_model( model_service.GetTunedModelRequest(), - name='name_value', + name="name_value", ) + @pytest.mark.asyncio async def test_get_tuned_model_flattened_async(): client = ModelServiceAsyncClient( @@ -1407,17 +1563,17 @@ async def test_get_tuned_model_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_tuned_model), - '__call__') as call: + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = tuned_model.TunedModel() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(tuned_model.TunedModel()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tuned_model.TunedModel() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.get_tuned_model( - name='name_value', + name="name_value", ) # Establish that the underlying call was made with the expected @@ -1425,9 +1581,10 @@ async def test_get_tuned_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].name - mock_val = 'name_value' + mock_val = "name_value" assert arg == mock_val + @pytest.mark.asyncio async def test_get_tuned_model_flattened_error_async(): client = ModelServiceAsyncClient( @@ -1439,15 +1596,18 @@ async def test_get_tuned_model_flattened_error_async(): with pytest.raises(ValueError): await client.get_tuned_model( model_service.GetTunedModelRequest(), - name='name_value', + name="name_value", ) -@pytest.mark.parametrize("request_type", [ - model_service.ListTunedModelsRequest, - dict, -]) -def test_list_tuned_models(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + model_service.ListTunedModelsRequest, + dict, + ], +) +def test_list_tuned_models(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1459,11 +1619,11 @@ def test_list_tuned_models(request_type, transport: str = 'grpc'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_tuned_models), - '__call__') as call: + type(client.transport.list_tuned_models), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListTunedModelsResponse( - next_page_token='next_page_token_value', + next_page_token="next_page_token_value", ) response = client.list_tuned_models(request) @@ -1474,7 +1634,7 @@ def test_list_tuned_models(request_type, transport: str = 'grpc'): # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListTunedModelsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_tuned_models_empty_call(): @@ -1482,20 +1642,23 @@ def test_list_tuned_models_empty_call(): # i.e. 
request == None and no flattened fields passed, work. client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_tuned_models), - '__call__') as call: + type(client.transport.list_tuned_models), "__call__" + ) as call: client.list_tuned_models() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.ListTunedModelsRequest() + @pytest.mark.asyncio -async def test_list_tuned_models_async(transport: str = 'grpc_asyncio', request_type=model_service.ListTunedModelsRequest): +async def test_list_tuned_models_async( + transport: str = "grpc_asyncio", request_type=model_service.ListTunedModelsRequest +): client = ModelServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1507,12 +1670,14 @@ async def test_list_tuned_models_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_tuned_models), - '__call__') as call: + type(client.transport.list_tuned_models), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListTunedModelsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListTunedModelsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_tuned_models(request) # Establish that the underlying gRPC stub method was called. @@ -1522,7 +1687,7 @@ async def test_list_tuned_models_async(transport: str = 'grpc_asyncio', request_ # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListTunedModelsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -1537,15 +1702,15 @@ def test_list_tuned_models_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_tuned_models), - '__call__') as call: + type(client.transport.list_tuned_models), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListTunedModelsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.list_tuned_models( page_size=951, - page_token='page_token_value', + page_token="page_token_value", ) # Establish that the underlying call was made with the expected @@ -1556,7 +1721,7 @@ def test_list_tuned_models_flattened(): mock_val = 951 assert arg == mock_val arg = args[0].page_token - mock_val = 'page_token_value' + mock_val = "page_token_value" assert arg == mock_val @@ -1571,9 +1736,10 @@ def test_list_tuned_models_flattened_error(): client.list_tuned_models( model_service.ListTunedModelsRequest(), page_size=951, - page_token='page_token_value', + page_token="page_token_value", ) + @pytest.mark.asyncio async def test_list_tuned_models_flattened_async(): client = ModelServiceAsyncClient( @@ -1582,17 +1748,19 @@ async def test_list_tuned_models_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. 
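    # The patch target below is the class of the bound stub, because Python
    # looks special methods up on the type: patching type(stub).__call__ is
    # what lets the mock intercept the invocation itself. A self-contained,
    # stdlib-only sketch of the same technique:
    #
    #     from unittest import mock
    #
    #     class Stub:
    #         def __call__(self, request):
    #             raise NotImplementedError
    #
    #     stub = Stub()
    #     with mock.patch.object(type(stub), "__call__") as call:
    #         call.return_value = "response"
    #         assert stub("request") == "response"
    #         call.assert_called_once()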
with mock.patch.object( - type(client.transport.list_tuned_models), - '__call__') as call: + type(client.transport.list_tuned_models), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = model_service.ListTunedModelsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model_service.ListTunedModelsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListTunedModelsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.list_tuned_models( page_size=951, - page_token='page_token_value', + page_token="page_token_value", ) # Establish that the underlying call was made with the expected @@ -1603,9 +1771,10 @@ async def test_list_tuned_models_flattened_async(): mock_val = 951 assert arg == mock_val arg = args[0].page_token - mock_val = 'page_token_value' + mock_val = "page_token_value" assert arg == mock_val + @pytest.mark.asyncio async def test_list_tuned_models_flattened_error_async(): client = ModelServiceAsyncClient( @@ -1618,7 +1787,7 @@ async def test_list_tuned_models_flattened_error_async(): await client.list_tuned_models( model_service.ListTunedModelsRequest(), page_size=951, - page_token='page_token_value', + page_token="page_token_value", ) @@ -1630,8 +1799,8 @@ def test_list_tuned_models_pager(transport_name: str = "grpc"): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_tuned_models), - '__call__') as call: + type(client.transport.list_tuned_models), "__call__" + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListTunedModelsResponse( @@ -1640,17 +1809,17 @@ def test_list_tuned_models_pager(transport_name: str = "grpc"): tuned_model.TunedModel(), tuned_model.TunedModel(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListTunedModelsResponse( tuned_models=[], - next_page_token='def', + next_page_token="def", ), model_service.ListTunedModelsResponse( tuned_models=[ tuned_model.TunedModel(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListTunedModelsResponse( tuned_models=[ @@ -1668,8 +1837,9 @@ def test_list_tuned_models_pager(transport_name: str = "grpc"): results = list(pager) assert len(results) == 6 - assert all(isinstance(i, tuned_model.TunedModel) - for i in results) + assert all(isinstance(i, tuned_model.TunedModel) for i in results) + + def test_list_tuned_models_pages(transport_name: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials, @@ -1678,8 +1848,8 @@ def test_list_tuned_models_pages(transport_name: str = "grpc"): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_tuned_models), - '__call__') as call: + type(client.transport.list_tuned_models), "__call__" + ) as call: # Set the response to a series of pages. 
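        # Assigning an iterable to side_effect makes the mock yield each value
        # in turn, so the pager below walks three token-linked pages and the
        # trailing RuntimeError guards against an unexpected fourth request.
        # A stdlib-only sketch of that behaviour:
        #
        #     from unittest import mock
        #
        #     call = mock.MagicMock(side_effect=["page1", "page2", RuntimeError])
        #     assert call() == "page1"
        #     assert call() == "page2"
        #     # a third call() would raise RuntimeError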
call.side_effect = ( model_service.ListTunedModelsResponse( @@ -1688,17 +1858,17 @@ def test_list_tuned_models_pages(transport_name: str = "grpc"): tuned_model.TunedModel(), tuned_model.TunedModel(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListTunedModelsResponse( tuned_models=[], - next_page_token='def', + next_page_token="def", ), model_service.ListTunedModelsResponse( tuned_models=[ tuned_model.TunedModel(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListTunedModelsResponse( tuned_models=[ @@ -1709,9 +1879,10 @@ def test_list_tuned_models_pages(transport_name: str = "grpc"): RuntimeError, ) pages = list(client.list_tuned_models(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_tuned_models_async_pager(): client = ModelServiceAsyncClient( @@ -1720,8 +1891,10 @@ async def test_list_tuned_models_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_tuned_models), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_tuned_models), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. call.side_effect = ( model_service.ListTunedModelsResponse( @@ -1730,17 +1903,17 @@ async def test_list_tuned_models_async_pager(): tuned_model.TunedModel(), tuned_model.TunedModel(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListTunedModelsResponse( tuned_models=[], - next_page_token='def', + next_page_token="def", ), model_service.ListTunedModelsResponse( tuned_models=[ tuned_model.TunedModel(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListTunedModelsResponse( tuned_models=[ @@ -1750,15 +1923,16 @@ async def test_list_tuned_models_async_pager(): ), RuntimeError, ) - async_pager = await client.list_tuned_models(request={},) - assert async_pager.next_page_token == 'abc' + async_pager = await client.list_tuned_models( + request={}, + ) + assert async_pager.next_page_token == "abc" responses = [] - async for response in async_pager: # pragma: no branch + async for response in async_pager: # pragma: no branch responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, tuned_model.TunedModel) - for i in responses) + assert all(isinstance(i, tuned_model.TunedModel) for i in responses) @pytest.mark.asyncio @@ -1769,8 +1943,10 @@ async def test_list_tuned_models_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_tuned_models), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_tuned_models), + "__call__", + new_callable=mock.AsyncMock, + ) as call: # Set the response to a series of pages. 
call.side_effect = ( model_service.ListTunedModelsResponse( @@ -1779,17 +1955,17 @@ async def test_list_tuned_models_async_pages(): tuned_model.TunedModel(), tuned_model.TunedModel(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListTunedModelsResponse( tuned_models=[], - next_page_token='def', + next_page_token="def", ), model_service.ListTunedModelsResponse( tuned_models=[ tuned_model.TunedModel(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListTunedModelsResponse( tuned_models=[ @@ -1802,18 +1978,22 @@ async def test_list_tuned_models_async_pages(): pages = [] # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 - async for page_ in ( # pragma: no branch + async for page_ in ( # pragma: no branch await client.list_tuned_models(request={}) ).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -@pytest.mark.parametrize("request_type", [ - model_service.CreateTunedModelRequest, - dict, -]) -def test_create_tuned_model(request_type, transport: str = 'grpc'): + +@pytest.mark.parametrize( + "request_type", + [ + model_service.CreateTunedModelRequest, + dict, + ], +) +def test_create_tuned_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1825,10 +2005,10 @@ def test_create_tuned_model(request_type, transport: str = 'grpc'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_tuned_model), - '__call__') as call: + type(client.transport.create_tuned_model), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/spam') + call.return_value = operations_pb2.Operation(name="operations/spam") response = client.create_tuned_model(request) # Establish that the underlying gRPC stub method was called. @@ -1845,20 +2025,23 @@ def test_create_tuned_model_empty_call(): # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_tuned_model), - '__call__') as call: + type(client.transport.create_tuned_model), "__call__" + ) as call: client.create_tuned_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.CreateTunedModelRequest() + @pytest.mark.asyncio -async def test_create_tuned_model_async(transport: str = 'grpc_asyncio', request_type=model_service.CreateTunedModelRequest): +async def test_create_tuned_model_async( + transport: str = "grpc_asyncio", request_type=model_service.CreateTunedModelRequest +): client = ModelServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1870,11 +2053,11 @@ async def test_create_tuned_model_async(transport: str = 'grpc_asyncio', request # Mock the actual call within the gRPC stub, and fake the request. 
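    # grpc_helpers_async.FakeUnaryUnaryCall wraps a plain value so the mocked
    # stub returns something awaitable, matching what the async client expects
    # from a unary-unary gRPC call. A stripped-down stand-in with the same
    # shape (hypothetical, for illustration only):
    #
    #     class FakeUnaryUnaryCall:
    #         def __init__(self, response):
    #             self._response = response
    #
    #         def __await__(self):
    #             async def _coro():
    #                 return self._response
    #             return _coro().__await__()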
with mock.patch.object( - type(client.transport.create_tuned_model), - '__call__') as call: + type(client.transport.create_tuned_model), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) response = await client.create_tuned_model(request) @@ -1899,15 +2082,19 @@ def test_create_tuned_model_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_tuned_model), - '__call__') as call: + type(client.transport.create_tuned_model), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_tuned_model( - tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), - tuned_model_id='tuned_model_id_value', + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + tuned_model_id="tuned_model_id_value", ) # Establish that the underlying call was made with the expected @@ -1915,10 +2102,14 @@ def test_create_tuned_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].tuned_model - mock_val = gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')) + mock_val = gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ) assert arg == mock_val arg = args[0].tuned_model_id - mock_val = 'tuned_model_id_value' + mock_val = "tuned_model_id_value" assert arg == mock_val @@ -1932,10 +2123,15 @@ def test_create_tuned_model_flattened_error(): with pytest.raises(ValueError): client.create_tuned_model( model_service.CreateTunedModelRequest(), - tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), - tuned_model_id='tuned_model_id_value', + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + tuned_model_id="tuned_model_id_value", ) + @pytest.mark.asyncio async def test_create_tuned_model_flattened_async(): client = ModelServiceAsyncClient( @@ -1944,19 +2140,23 @@ async def test_create_tuned_model_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_tuned_model), - '__call__') as call: + type(client.transport.create_tuned_model), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = operations_pb2.Operation(name='operations/op') + call.return_value = operations_pb2.Operation(name="operations/op") call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( - operations_pb2.Operation(name='operations/spam') + operations_pb2.Operation(name="operations/spam") ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
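        # CreateTunedModel is a long-running operation, so the stub is faked
        # with an operations_pb2.Operation rather than a TunedModel; a real
        # client would poll that operation until it completes. Roughly:
        #
        #     from google.longrunning import operations_pb2
        #
        #     op = operations_pb2.Operation(name="operations/spam")
        #     assert not op.done  # polling continues until done is set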
response = await client.create_tuned_model( - tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), - tuned_model_id='tuned_model_id_value', + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + tuned_model_id="tuned_model_id_value", ) # Establish that the underlying call was made with the expected @@ -1964,12 +2164,17 @@ async def test_create_tuned_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].tuned_model - mock_val = gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')) + mock_val = gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ) assert arg == mock_val arg = args[0].tuned_model_id - mock_val = 'tuned_model_id_value' + mock_val = "tuned_model_id_value" assert arg == mock_val + @pytest.mark.asyncio async def test_create_tuned_model_flattened_error_async(): client = ModelServiceAsyncClient( @@ -1981,16 +2186,23 @@ async def test_create_tuned_model_flattened_error_async(): with pytest.raises(ValueError): await client.create_tuned_model( model_service.CreateTunedModelRequest(), - tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), - tuned_model_id='tuned_model_id_value', + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + tuned_model_id="tuned_model_id_value", ) -@pytest.mark.parametrize("request_type", [ - model_service.UpdateTunedModelRequest, - dict, -]) -def test_update_tuned_model(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + model_service.UpdateTunedModelRequest, + dict, + ], +) +def test_update_tuned_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2002,18 +2214,18 @@ def test_update_tuned_model(request_type, transport: str = 'grpc'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_tuned_model), - '__call__') as call: + type(client.transport.update_tuned_model), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gag_tuned_model.TunedModel( - name='name_value', - display_name='display_name_value', - description='description_value', + name="name_value", + display_name="display_name_value", + description="description_value", temperature=0.1198, top_p=0.546, top_k=541, state=gag_tuned_model.TunedModel.State.CREATING, - base_model='base_model_value', + base_model="base_model_value", ) response = client.update_tuned_model(request) @@ -2024,9 +2236,9 @@ def test_update_tuned_model(request_type, transport: str = 'grpc'): # Establish that the response is the type that we expect. 
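    # Proto float fields round-trip through 32-bit floats, so exact equality
    # is unreliable; the assertions below use math.isclose with a relative
    # tolerance instead. For example:
    #
    #     import math
    #
    #     assert math.isclose(0.1198, 0.1198000001, rel_tol=1e-6)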
assert isinstance(response, gag_tuned_model.TunedModel) - assert response.name == 'name_value' - assert response.display_name == 'display_name_value' - assert response.description == 'description_value' + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) assert response.top_k == 541 @@ -2038,20 +2250,23 @@ def test_update_tuned_model_empty_call(): # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_tuned_model), - '__call__') as call: + type(client.transport.update_tuned_model), "__call__" + ) as call: client.update_tuned_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.UpdateTunedModelRequest() + @pytest.mark.asyncio -async def test_update_tuned_model_async(transport: str = 'grpc_asyncio', request_type=model_service.UpdateTunedModelRequest): +async def test_update_tuned_model_async( + transport: str = "grpc_asyncio", request_type=model_service.UpdateTunedModelRequest +): client = ModelServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2063,18 +2278,20 @@ async def test_update_tuned_model_async(transport: str = 'grpc_asyncio', request # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_tuned_model), - '__call__') as call: + type(client.transport.update_tuned_model), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(gag_tuned_model.TunedModel( - name='name_value', - display_name='display_name_value', - description='description_value', - temperature=0.1198, - top_p=0.546, - top_k=541, - state=gag_tuned_model.TunedModel.State.CREATING, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_tuned_model.TunedModel( + name="name_value", + display_name="display_name_value", + description="description_value", + temperature=0.1198, + top_p=0.546, + top_k=541, + state=gag_tuned_model.TunedModel.State.CREATING, + ) + ) response = await client.update_tuned_model(request) # Establish that the underlying gRPC stub method was called. @@ -2084,9 +2301,9 @@ async def test_update_tuned_model_async(transport: str = 'grpc_asyncio', request # Establish that the response is the type that we expect. assert isinstance(response, gag_tuned_model.TunedModel) - assert response.name == 'name_value' - assert response.display_name == 'display_name_value' - assert response.description == 'description_value' + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) assert response.top_k == 541 @@ -2107,12 +2324,12 @@ def test_update_tuned_model_field_headers(): # a field header. Set these to a non-empty value. 
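    # Routing headers are sent as the x-goog-request-params metadata entry,
    # whose value is the field path of the resource name; for a nested field
    # the full path is used, which is exactly what this test asserts:
    #
    #     ("x-goog-request-params", "tuned_model.name=name_value") in kw["metadata"]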
request = model_service.UpdateTunedModelRequest() - request.tuned_model.name = 'name_value' + request.tuned_model.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_tuned_model), - '__call__') as call: + type(client.transport.update_tuned_model), "__call__" + ) as call: call.return_value = gag_tuned_model.TunedModel() client.update_tuned_model(request) @@ -2124,9 +2341,9 @@ def test_update_tuned_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'tuned_model.name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "tuned_model.name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -2139,13 +2356,15 @@ async def test_update_tuned_model_field_headers_async(): # a field header. Set these to a non-empty value. request = model_service.UpdateTunedModelRequest() - request.tuned_model.name = 'name_value' + request.tuned_model.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_tuned_model), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_tuned_model.TunedModel()) + type(client.transport.update_tuned_model), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_tuned_model.TunedModel() + ) await client.update_tuned_model(request) # Establish that the underlying gRPC stub method was called. @@ -2156,9 +2375,9 @@ async def test_update_tuned_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'tuned_model.name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "tuned_model.name=name_value", + ) in kw["metadata"] def test_update_tuned_model_flattened(): @@ -2168,15 +2387,19 @@ def test_update_tuned_model_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_tuned_model), - '__call__') as call: + type(client.transport.update_tuned_model), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gag_tuned_model.TunedModel() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.update_tuned_model( - tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), - update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -2184,10 +2407,14 @@ def test_update_tuned_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].tuned_model - mock_val = gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')) + mock_val = gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ) assert arg == mock_val arg = args[0].update_mask - mock_val = field_mask_pb2.FieldMask(paths=['paths_value']) + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) assert arg == mock_val @@ -2201,10 +2428,15 @@ def test_update_tuned_model_flattened_error(): with pytest.raises(ValueError): client.update_tuned_model( model_service.UpdateTunedModelRequest(), - tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), - update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) + @pytest.mark.asyncio async def test_update_tuned_model_flattened_async(): client = ModelServiceAsyncClient( @@ -2213,17 +2445,23 @@ async def test_update_tuned_model_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_tuned_model), - '__call__') as call: + type(client.transport.update_tuned_model), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gag_tuned_model.TunedModel() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_tuned_model.TunedModel()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_tuned_model.TunedModel() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.update_tuned_model( - tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), - update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -2231,12 +2469,17 @@ async def test_update_tuned_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].tuned_model - mock_val = gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')) + mock_val = gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ) assert arg == mock_val arg = args[0].update_mask - mock_val = field_mask_pb2.FieldMask(paths=['paths_value']) + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) assert arg == mock_val + @pytest.mark.asyncio async def test_update_tuned_model_flattened_error_async(): client = ModelServiceAsyncClient( @@ -2248,16 +2491,23 @@ async def test_update_tuned_model_flattened_error_async(): with pytest.raises(ValueError): await client.update_tuned_model( model_service.UpdateTunedModelRequest(), - tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), - update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) -@pytest.mark.parametrize("request_type", [ - model_service.DeleteTunedModelRequest, - dict, -]) -def test_delete_tuned_model(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + model_service.DeleteTunedModelRequest, + dict, + ], +) +def test_delete_tuned_model(request_type, transport: str = "grpc"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2269,8 +2519,8 @@ def test_delete_tuned_model(request_type, transport: str = 'grpc'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_tuned_model), - '__call__') as call: + type(client.transport.delete_tuned_model), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None response = client.delete_tuned_model(request) @@ -2289,20 +2539,23 @@ def test_delete_tuned_model_empty_call(): # i.e. request == None and no flattened fields passed, work. client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. 
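    # The empty-call tests invoke the method with no arguments and assert that
    # the stub still receives a default-constructed request proto. Proto3
    # messages compare by field value, so any two defaults are equal:
    #
    #     assert model_service.DeleteTunedModelRequest() == model_service.DeleteTunedModelRequest()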
with mock.patch.object( - type(client.transport.delete_tuned_model), - '__call__') as call: + type(client.transport.delete_tuned_model), "__call__" + ) as call: client.delete_tuned_model() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == model_service.DeleteTunedModelRequest() + @pytest.mark.asyncio -async def test_delete_tuned_model_async(transport: str = 'grpc_asyncio', request_type=model_service.DeleteTunedModelRequest): +async def test_delete_tuned_model_async( + transport: str = "grpc_asyncio", request_type=model_service.DeleteTunedModelRequest +): client = ModelServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2314,8 +2567,8 @@ async def test_delete_tuned_model_async(transport: str = 'grpc_asyncio', request # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_tuned_model), - '__call__') as call: + type(client.transport.delete_tuned_model), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) response = await client.delete_tuned_model(request) @@ -2343,12 +2596,12 @@ def test_delete_tuned_model_field_headers(): # a field header. Set these to a non-empty value. request = model_service.DeleteTunedModelRequest() - request.name = 'name_value' + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_tuned_model), - '__call__') as call: + type(client.transport.delete_tuned_model), "__call__" + ) as call: call.return_value = None client.delete_tuned_model(request) @@ -2360,9 +2613,9 @@ def test_delete_tuned_model_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -2375,12 +2628,12 @@ async def test_delete_tuned_model_field_headers_async(): # a field header. Set these to a non-empty value. request = model_service.DeleteTunedModelRequest() - request.name = 'name_value' + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_tuned_model), - '__call__') as call: + type(client.transport.delete_tuned_model), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.delete_tuned_model(request) @@ -2392,9 +2645,9 @@ async def test_delete_tuned_model_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] def test_delete_tuned_model_flattened(): @@ -2404,14 +2657,14 @@ def test_delete_tuned_model_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_tuned_model), - '__call__') as call: + type(client.transport.delete_tuned_model), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.delete_tuned_model( - name='name_value', + name="name_value", ) # Establish that the underlying call was made with the expected @@ -2419,7 +2672,7 @@ def test_delete_tuned_model_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].name - mock_val = 'name_value' + mock_val = "name_value" assert arg == mock_val @@ -2433,9 +2686,10 @@ def test_delete_tuned_model_flattened_error(): with pytest.raises(ValueError): client.delete_tuned_model( model_service.DeleteTunedModelRequest(), - name='name_value', + name="name_value", ) + @pytest.mark.asyncio async def test_delete_tuned_model_flattened_async(): client = ModelServiceAsyncClient( @@ -2444,8 +2698,8 @@ async def test_delete_tuned_model_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_tuned_model), - '__call__') as call: + type(client.transport.delete_tuned_model), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -2453,7 +2707,7 @@ async def test_delete_tuned_model_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.delete_tuned_model( - name='name_value', + name="name_value", ) # Establish that the underlying call was made with the expected @@ -2461,9 +2715,10 @@ async def test_delete_tuned_model_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].name - mock_val = 'name_value' + mock_val = "name_value" assert arg == mock_val + @pytest.mark.asyncio async def test_delete_tuned_model_flattened_error_async(): client = ModelServiceAsyncClient( @@ -2475,14 +2730,17 @@ async def test_delete_tuned_model_flattened_error_async(): with pytest.raises(ValueError): await client.delete_tuned_model( model_service.DeleteTunedModelRequest(), - name='name_value', + name="name_value", ) -@pytest.mark.parametrize("request_type", [ - model_service.GetModelRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + model_service.GetModelRequest, + dict, + ], +) def test_get_model_rest(request_type): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -2490,24 +2748,24 @@ def test_get_model_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'name': 'models/sample1'} + request_init = {"name": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
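        # The REST tests never open a socket: requests.Session.request is
        # mocked and handed a Response whose private _content is a
        # JSON-serialized proto. model.Model.pb(...) unwraps the proto-plus
        # wrapper into the raw protobuf message that json_format understands;
        # roughly:
        #
        #     pb = model.Model.pb(model.Model(name="name_value"))
        #     body = json_format.MessageToJson(pb).encode("UTF-8")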
return_value = model.Model( - name='name_value', - base_model_id='base_model_id_value', - version='version_value', - display_name='display_name_value', - description='description_value', - input_token_limit=1838, - output_token_limit=1967, - supported_generation_methods=['supported_generation_methods_value'], - temperature=0.1198, - top_p=0.546, - top_k=541, + name="name_value", + base_model_id="base_model_id_value", + version="version_value", + display_name="display_name_value", + description="description_value", + input_token_limit=1838, + output_token_limit=1967, + supported_generation_methods=["supported_generation_methods_value"], + temperature=0.1198, + top_p=0.546, + top_k=541, ) # Wrap the value into a proper Response obj @@ -2516,20 +2774,22 @@ def test_get_model_rest(request_type): pb_return_value = model.Model.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.get_model(request) # Establish that the response is the type that we expect. assert isinstance(response, model.Model) - assert response.name == 'name_value' - assert response.base_model_id == 'base_model_id_value' - assert response.version == 'version_value' - assert response.display_name == 'display_name_value' - assert response.description == 'description_value' + assert response.name == "name_value" + assert response.base_model_id == "base_model_id_value" + assert response.version == "version_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" assert response.input_token_limit == 1838 assert response.output_token_limit == 1967 - assert response.supported_generation_methods == ['supported_generation_methods_value'] + assert response.supported_generation_methods == [ + "supported_generation_methods_value" + ] assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) assert response.top_k == 541 @@ -2542,49 +2802,55 @@ def test_get_model_rest_required_fields(request_type=model_service.GetModelReque request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_model._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_model._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["name"] = 'name_value' + jsonified_request["name"] = "name_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_model._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_model._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert 
"name" in jsonified_request - assert jsonified_request["name"] == 'name_value' + assert jsonified_request["name"] == "name_value" client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = model.Model() # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "get", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, } transcode.return_value = transcode_result @@ -2594,36 +2860,43 @@ def test_get_model_rest_required_fields(request_type=model_service.GetModelReque pb_return_value = model.Model.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.get_model(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_get_model_rest_unset_required_fields(): - transport = transports.ModelServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.get_model._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("name", ))) + assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_get_model_rest_interceptors(null_interceptor): transport = transports.ModelServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) client = ModelServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "post_get_model") as post, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "pre_get_model") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ModelServiceRestInterceptor, "post_get_model" + ) as post, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_get_model" + ) as pre: pre.assert_not_called() post.assert_not_called() pb_message = 
model_service.GetModelRequest.pb(model_service.GetModelRequest()) @@ -2640,31 +2913,41 @@ def test_get_model_rest_interceptors(null_interceptor): req.return_value._content = model.Model.to_json(model.Model()) request = model_service.GetModelRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = model.Model() - client.get_model(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.get_model( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_get_model_rest_bad_request(transport: str = 'rest', request_type=model_service.GetModelRequest): +def test_get_model_rest_bad_request( + transport: str = "rest", request_type=model_service.GetModelRequest +): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'name': 'models/sample1'} + request_init = {"name": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -2680,16 +2963,16 @@ def test_get_model_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = model.Model() # get arguments that satisfy an http rule for this method - sample_request = {'name': 'models/sample1'} + sample_request = {"name": "models/sample1"} # get truthy value for each flattened field mock_args = dict( - name='name_value', + name="name_value", ) mock_args.update(sample_request) @@ -2698,7 +2981,7 @@ def test_get_model_rest_flattened(): response_value.status_code = 200 pb_return_value = model.Model.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.get_model(**mock_args) @@ -2707,10 +2990,12 @@ def test_get_model_rest_flattened(): # request object values. 
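    # path_template.validate checks the final URL against the http rule's
    # pattern, where {name=models/*} matches exactly one path segment
    # (assuming google.api_core.path_template semantics):
    #
    #     path_template.validate("host/v1beta3/{name=models/*}", "host/v1beta3/models/sample1")  # True
    #     path_template.validate("host/v1beta3/{name=models/*}", "host/v1beta3/models/a/b")      # False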
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{name=models/*}" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{name=models/*}" % client.transport._host, args[1] + ) -def test_get_model_rest_flattened_error(transport: str = 'rest'): +def test_get_model_rest_flattened_error(transport: str = "rest"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2721,21 +3006,23 @@ def test_get_model_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.get_model( model_service.GetModelRequest(), - name='name_value', + name="name_value", ) def test_get_model_rest_error(): client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) -@pytest.mark.parametrize("request_type", [ - model_service.ListModelsRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + model_service.ListModelsRequest, + dict, + ], +) def test_list_models_rest(request_type): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -2747,10 +3034,10 @@ def test_list_models_rest(request_type): request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = model_service.ListModelsResponse( - next_page_token='next_page_token_value', + next_page_token="next_page_token_value", ) # Wrap the value into a proper Response obj @@ -2759,29 +3046,38 @@ def test_list_models_rest(request_type): pb_return_value = model_service.ListModelsResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_models(request) # Establish that the response is the type that we expect. 
assert isinstance(response, pagers.ListModelsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.parametrize("null_interceptor", [True, False]) def test_list_models_rest_interceptors(null_interceptor): transport = transports.ModelServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) client = ModelServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "post_list_models") as post, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "pre_list_models") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ModelServiceRestInterceptor, "post_list_models" + ) as post, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_list_models" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = model_service.ListModelsRequest.pb(model_service.ListModelsRequest()) + pb_message = model_service.ListModelsRequest.pb( + model_service.ListModelsRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -2792,23 +3088,33 @@ def test_list_models_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = model_service.ListModelsResponse.to_json(model_service.ListModelsResponse()) + req.return_value._content = model_service.ListModelsResponse.to_json( + model_service.ListModelsResponse() + ) request = model_service.ListModelsRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = model_service.ListModelsResponse() - client.list_models(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.list_models( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_list_models_rest_bad_request(transport: str = 'rest', request_type=model_service.ListModelsRequest): +def test_list_models_rest_bad_request( + transport: str = "rest", request_type=model_service.ListModelsRequest +): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2819,7 +3125,9 @@ def test_list_models_rest_bad_request(transport: str = 'rest', request_type=mode request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -2835,7 +3143,7 @@ def test_list_models_rest_flattened(): ) # Mock the http request call within the method and fake a response. 
- with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = model_service.ListModelsResponse() @@ -2845,7 +3153,7 @@ def test_list_models_rest_flattened(): # get truthy value for each flattened field mock_args = dict( page_size=951, - page_token='page_token_value', + page_token="page_token_value", ) mock_args.update(sample_request) @@ -2854,7 +3162,7 @@ def test_list_models_rest_flattened(): response_value.status_code = 200 pb_return_value = model_service.ListModelsResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.list_models(**mock_args) @@ -2863,10 +3171,12 @@ def test_list_models_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/models" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/models" % client.transport._host, args[1] + ) -def test_list_models_rest_flattened_error(transport: str = 'rest'): +def test_list_models_rest_flattened_error(transport: str = "rest"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2878,20 +3188,20 @@ def test_list_models_rest_flattened_error(transport: str = 'rest'): client.list_models( model_service.ListModelsRequest(), page_size=951, - page_token='page_token_value', + page_token="page_token_value", ) -def test_list_models_rest_pager(transport: str = 'rest'): +def test_list_models_rest_pager(transport: str = "rest"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # TODO(kbandes): remove this mock unless there's a good reason for it. 
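        # Unlike the gRPC pager tests, the REST pager is fed whole HTTP
        # responses: each page proto is serialized with to_json, wrapped in a
        # Response, and queued on req.side_effect so each page fetch pops the
        # next reply. Schematically (make_response is hypothetical):
        #
        #     bodies = [model_service.ListModelsResponse.to_json(p) for p in pages]
        #     req.side_effect = [make_response(b) for b in bodies]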
- #with mock.patch.object(path_template, 'transcode') as transcode: + # with mock.patch.object(path_template, 'transcode') as transcode: # Set the response as a series of pages response = ( model_service.ListModelsResponse( @@ -2900,17 +3210,17 @@ def test_list_models_rest_pager(transport: str = 'rest'): model.Model(), model.Model(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListModelsResponse( models=[], - next_page_token='def', + next_page_token="def", ), model_service.ListModelsResponse( models=[ model.Model(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListModelsResponse( models=[ @@ -2926,7 +3236,7 @@ def test_list_models_rest_pager(transport: str = 'rest'): response = tuple(model_service.ListModelsResponse.to_json(x) for x in response) return_values = tuple(Response() for i in response) for return_val, response_val in zip(return_values, response): - return_val._content = response_val.encode('UTF-8') + return_val._content = response_val.encode("UTF-8") return_val.status_code = 200 req.side_effect = return_values @@ -2936,18 +3246,20 @@ def test_list_models_rest_pager(transport: str = 'rest'): results = list(pager) assert len(results) == 6 - assert all(isinstance(i, model.Model) - for i in results) + assert all(isinstance(i, model.Model) for i in results) pages = list(client.list_models(request=sample_request).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -@pytest.mark.parametrize("request_type", [ - model_service.GetTunedModelRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + model_service.GetTunedModelRequest, + dict, + ], +) def test_get_tuned_model_rest(request_type): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -2955,21 +3267,21 @@ def test_get_tuned_model_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'name': 'tunedModels/sample1'} + request_init = {"name": "tunedModels/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = tuned_model.TunedModel( - name='name_value', - display_name='display_name_value', - description='description_value', - temperature=0.1198, - top_p=0.546, - top_k=541, - state=tuned_model.TunedModel.State.CREATING, - base_model='base_model_value', + name="name_value", + display_name="display_name_value", + description="description_value", + temperature=0.1198, + top_p=0.546, + top_k=541, + state=tuned_model.TunedModel.State.CREATING, + base_model="base_model_value", ) # Wrap the value into a proper Response obj @@ -2978,71 +3290,79 @@ def test_get_tuned_model_rest(request_type): pb_return_value = tuned_model.TunedModel.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.get_tuned_model(request) # Establish that the response is the type that we expect. 
assert isinstance(response, tuned_model.TunedModel) - assert response.name == 'name_value' - assert response.display_name == 'display_name_value' - assert response.description == 'description_value' + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) assert response.top_k == 541 assert response.state == tuned_model.TunedModel.State.CREATING -def test_get_tuned_model_rest_required_fields(request_type=model_service.GetTunedModelRequest): +def test_get_tuned_model_rest_required_fields( + request_type=model_service.GetTunedModelRequest, +): transport_class = transports.ModelServiceRestTransport request_init = {} request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_tuned_model._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_tuned_model._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["name"] = 'name_value' + jsonified_request["name"] = "name_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_tuned_model._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_tuned_model._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "name" in jsonified_request - assert jsonified_request["name"] == 'name_value' + assert jsonified_request["name"] == "name_value" client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = tuned_model.TunedModel() # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. 
pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "get", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, } transcode.return_value = transcode_result @@ -3052,39 +3372,48 @@ def test_get_tuned_model_rest_required_fields(request_type=model_service.GetTune pb_return_value = tuned_model.TunedModel.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.get_tuned_model(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_get_tuned_model_rest_unset_required_fields(): - transport = transports.ModelServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.get_tuned_model._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("name", ))) + assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_get_tuned_model_rest_interceptors(null_interceptor): transport = transports.ModelServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) client = ModelServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "post_get_tuned_model") as post, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "pre_get_tuned_model") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ModelServiceRestInterceptor, "post_get_tuned_model" + ) as post, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_get_tuned_model" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = model_service.GetTunedModelRequest.pb(model_service.GetTunedModelRequest()) + pb_message = model_service.GetTunedModelRequest.pb( + model_service.GetTunedModelRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -3095,34 +3424,46 @@ def test_get_tuned_model_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = tuned_model.TunedModel.to_json(tuned_model.TunedModel()) + req.return_value._content = tuned_model.TunedModel.to_json( + tuned_model.TunedModel() + ) request = model_service.GetTunedModelRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = tuned_model.TunedModel() - client.get_tuned_model(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.get_tuned_model( + request, + metadata=[ + ("key", "val"), + 
("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_get_tuned_model_rest_bad_request(transport: str = 'rest', request_type=model_service.GetTunedModelRequest): +def test_get_tuned_model_rest_bad_request( + transport: str = "rest", request_type=model_service.GetTunedModelRequest +): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'name': 'tunedModels/sample1'} + request_init = {"name": "tunedModels/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -3138,16 +3479,16 @@ def test_get_tuned_model_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = tuned_model.TunedModel() # get arguments that satisfy an http rule for this method - sample_request = {'name': 'tunedModels/sample1'} + sample_request = {"name": "tunedModels/sample1"} # get truthy value for each flattened field mock_args = dict( - name='name_value', + name="name_value", ) mock_args.update(sample_request) @@ -3156,7 +3497,7 @@ def test_get_tuned_model_rest_flattened(): response_value.status_code = 200 pb_return_value = tuned_model.TunedModel.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.get_tuned_model(**mock_args) @@ -3165,10 +3506,12 @@ def test_get_tuned_model_rest_flattened(): # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{name=tunedModels/*}" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{name=tunedModels/*}" % client.transport._host, args[1] + ) -def test_get_tuned_model_rest_flattened_error(transport: str = 'rest'): +def test_get_tuned_model_rest_flattened_error(transport: str = "rest"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3179,21 +3522,23 @@ def test_get_tuned_model_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.get_tuned_model( model_service.GetTunedModelRequest(), - name='name_value', + name="name_value", ) def test_get_tuned_model_rest_error(): client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) -@pytest.mark.parametrize("request_type", [ - model_service.ListTunedModelsRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + model_service.ListTunedModelsRequest, + dict, + ], +) def test_list_tuned_models_rest(request_type): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -3205,10 +3550,10 @@ def test_list_tuned_models_rest(request_type): request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = model_service.ListTunedModelsResponse( - next_page_token='next_page_token_value', + next_page_token="next_page_token_value", ) # Wrap the value into a proper Response obj @@ -3217,29 +3562,38 @@ def test_list_tuned_models_rest(request_type): pb_return_value = model_service.ListTunedModelsResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_tuned_models(request) # Establish that the response is the type that we expect. 
assert isinstance(response, pagers.ListTunedModelsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.parametrize("null_interceptor", [True, False]) def test_list_tuned_models_rest_interceptors(null_interceptor): transport = transports.ModelServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) client = ModelServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "post_list_tuned_models") as post, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "pre_list_tuned_models") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ModelServiceRestInterceptor, "post_list_tuned_models" + ) as post, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_list_tuned_models" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = model_service.ListTunedModelsRequest.pb(model_service.ListTunedModelsRequest()) + pb_message = model_service.ListTunedModelsRequest.pb( + model_service.ListTunedModelsRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -3250,23 +3604,33 @@ def test_list_tuned_models_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = model_service.ListTunedModelsResponse.to_json(model_service.ListTunedModelsResponse()) + req.return_value._content = model_service.ListTunedModelsResponse.to_json( + model_service.ListTunedModelsResponse() + ) request = model_service.ListTunedModelsRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = model_service.ListTunedModelsResponse() - client.list_tuned_models(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.list_tuned_models( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_list_tuned_models_rest_bad_request(transport: str = 'rest', request_type=model_service.ListTunedModelsRequest): +def test_list_tuned_models_rest_bad_request( + transport: str = "rest", request_type=model_service.ListTunedModelsRequest +): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3277,7 +3641,9 @@ def test_list_tuned_models_rest_bad_request(transport: str = 'rest', request_typ request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -3293,7 +3659,7 @@ def test_list_tuned_models_rest_flattened(): ) # Mock the http request call within the method and fake a response. 
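# The interception pattern used throughout these REST tests, reduced to its
# core: patch requests.Session.request, hand back a prebuilt Response, and no
# real network traffic ever happens. A minimal self-contained sketch (the URL
# is illustrative):
from unittest import mock
from requests import Response, Session

fake = Response()
fake.status_code = 200
fake._content = b'{"ok": true}'

with mock.patch.object(Session, "request", return_value=fake) as req:
    got = Session().request("GET", "https://example.com/v1beta3/tunedModels")
    assert got.status_code == 200 and req.call_count == 1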
- with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = model_service.ListTunedModelsResponse() @@ -3303,7 +3669,7 @@ def test_list_tuned_models_rest_flattened(): # get truthy value for each flattened field mock_args = dict( page_size=951, - page_token='page_token_value', + page_token="page_token_value", ) mock_args.update(sample_request) @@ -3312,7 +3678,7 @@ def test_list_tuned_models_rest_flattened(): response_value.status_code = 200 pb_return_value = model_service.ListTunedModelsResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.list_tuned_models(**mock_args) @@ -3321,10 +3687,12 @@ def test_list_tuned_models_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/tunedModels" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/tunedModels" % client.transport._host, args[1] + ) -def test_list_tuned_models_rest_flattened_error(transport: str = 'rest'): +def test_list_tuned_models_rest_flattened_error(transport: str = "rest"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3336,20 +3704,20 @@ def test_list_tuned_models_rest_flattened_error(transport: str = 'rest'): client.list_tuned_models( model_service.ListTunedModelsRequest(), page_size=951, - page_token='page_token_value', + page_token="page_token_value", ) -def test_list_tuned_models_rest_pager(transport: str = 'rest'): +def test_list_tuned_models_rest_pager(transport: str = "rest"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # TODO(kbandes): remove this mock unless there's a good reason for it. 
- #with mock.patch.object(path_template, 'transcode') as transcode: + # with mock.patch.object(path_template, 'transcode') as transcode: # Set the response as a series of pages response = ( model_service.ListTunedModelsResponse( @@ -3358,17 +3726,17 @@ def test_list_tuned_models_rest_pager(transport: str = 'rest'): tuned_model.TunedModel(), tuned_model.TunedModel(), ], - next_page_token='abc', + next_page_token="abc", ), model_service.ListTunedModelsResponse( tuned_models=[], - next_page_token='def', + next_page_token="def", ), model_service.ListTunedModelsResponse( tuned_models=[ tuned_model.TunedModel(), ], - next_page_token='ghi', + next_page_token="ghi", ), model_service.ListTunedModelsResponse( tuned_models=[ @@ -3381,10 +3749,12 @@ def test_list_tuned_models_rest_pager(transport: str = 'rest'): response = response + response # Wrap the values into proper Response objs - response = tuple(model_service.ListTunedModelsResponse.to_json(x) for x in response) + response = tuple( + model_service.ListTunedModelsResponse.to_json(x) for x in response + ) return_values = tuple(Response() for i in response) for return_val, response_val in zip(return_values, response): - return_val._content = response_val.encode('UTF-8') + return_val._content = response_val.encode("UTF-8") return_val.status_code = 200 req.side_effect = return_values @@ -3394,18 +3764,20 @@ def test_list_tuned_models_rest_pager(transport: str = 'rest'): results = list(pager) assert len(results) == 6 - assert all(isinstance(i, tuned_model.TunedModel) - for i in results) + assert all(isinstance(i, tuned_model.TunedModel) for i in results) pages = list(client.list_tuned_models(request=sample_request).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -@pytest.mark.parametrize("request_type", [ - model_service.CreateTunedModelRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + model_service.CreateTunedModelRequest, + dict, + ], +) def test_create_tuned_model_rest(request_type): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -3414,20 +3786,54 @@ def test_create_tuned_model_rest(request_type): # send a request that will satisfy transcoding request_init = {} - request_init["tuned_model"] = {'tuned_model_source': {'tuned_model': 'tuned_model_value', 'base_model': 'base_model_value'}, 'base_model': 'base_model_value', 'name': 'name_value', 'display_name': 'display_name_value', 'description': 'description_value', 'temperature': 0.1198, 'top_p': 0.546, 'top_k': 541, 'state': 1, 'create_time': {'seconds': 751, 'nanos': 543}, 'update_time': {}, 'tuning_task': {'start_time': {}, 'complete_time': {}, 'snapshots': [{'step': 444, 'epoch': 527, 'mean_loss': 0.961, 'compute_time': {}}], 'training_data': {'examples': {'examples': [{'text_input': 'text_input_value', 'output': 'output_value'}]}}, 'hyperparameters': {'epoch_count': 1175, 'batch_size': 1052, 'learning_rate': 0.1371}}} + request_init["tuned_model"] = { + "tuned_model_source": { + "tuned_model": "tuned_model_value", + "base_model": "base_model_value", + }, + "base_model": "base_model_value", + "name": "name_value", + "display_name": "display_name_value", + "description": "description_value", + "temperature": 0.1198, + "top_p": 0.546, + "top_k": 541, + "state": 1, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "tuning_task": { + "start_time": {}, + "complete_time": {}, + "snapshots": [ + 
{"step": 444, "epoch": 527, "mean_loss": 0.961, "compute_time": {}} + ], + "training_data": { + "examples": { + "examples": [ + {"text_input": "text_input_value", "output": "output_value"} + ] + } + }, + "hyperparameters": { + "epoch_count": 1175, + "batch_size": 1052, + "learning_rate": 0.1371, + }, + }, + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name='operations/spam') + return_value = operations_pb2.Operation(name="operations/spam") # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 json_return_value = json_format.MessageToJson(return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.create_tuned_model(request) @@ -3435,95 +3841,113 @@ def test_create_tuned_model_rest(request_type): assert response.operation.name == "operations/spam" -def test_create_tuned_model_rest_required_fields(request_type=model_service.CreateTunedModelRequest): +def test_create_tuned_model_rest_required_fields( + request_type=model_service.CreateTunedModelRequest, +): transport_class = transports.ModelServiceRestTransport request_init = {} request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).create_tuned_model._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_tuned_model._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).create_tuned_model._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_tuned_model._get_unset_required_fields(jsonified_request) # Check that path parameters and body parameters are not mixing in. - assert not set(unset_fields) - set(("tuned_model_id", )) + assert not set(unset_fields) - set(("tuned_model_id",)) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name='operations/spam') + return_value = operations_pb2.Operation(name="operations/spam") # Mock the http request call within the method and fake a response. 
- with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "post", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, } - transcode_result['body'] = pb_request + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 json_return_value = json_format.MessageToJson(return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.create_tuned_model(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_create_tuned_model_rest_unset_required_fields(): - transport = transports.ModelServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.create_tuned_model._get_unset_required_fields({}) - assert set(unset_fields) == (set(("tunedModelId", )) & set(("tunedModel", ))) + assert set(unset_fields) == (set(("tunedModelId",)) & set(("tunedModel",))) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_create_tuned_model_rest_interceptors(null_interceptor): transport = transports.ModelServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) client = ModelServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(operation.Operation, "_set_result_from_operation"), \ - mock.patch.object(transports.ModelServiceRestInterceptor, "post_create_tuned_model") as post, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "pre_create_tuned_model") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.ModelServiceRestInterceptor, "post_create_tuned_model" + ) as post, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_create_tuned_model" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = model_service.CreateTunedModelRequest.pb(model_service.CreateTunedModelRequest()) + pb_message = model_service.CreateTunedModelRequest.pb( + model_service.CreateTunedModelRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ 
-3534,23 +3958,33 @@ def test_create_tuned_model_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = json_format.MessageToJson(operations_pb2.Operation()) + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) request = model_service.CreateTunedModelRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = operations_pb2.Operation() - client.create_tuned_model(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.create_tuned_model( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_create_tuned_model_rest_bad_request(transport: str = 'rest', request_type=model_service.CreateTunedModelRequest): +def test_create_tuned_model_rest_bad_request( + transport: str = "rest", request_type=model_service.CreateTunedModelRequest +): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3558,11 +3992,47 @@ def test_create_tuned_model_rest_bad_request(transport: str = 'rest', request_ty # send a request that will satisfy transcoding request_init = {} - request_init["tuned_model"] = {'tuned_model_source': {'tuned_model': 'tuned_model_value', 'base_model': 'base_model_value'}, 'base_model': 'base_model_value', 'name': 'name_value', 'display_name': 'display_name_value', 'description': 'description_value', 'temperature': 0.1198, 'top_p': 0.546, 'top_k': 541, 'state': 1, 'create_time': {'seconds': 751, 'nanos': 543}, 'update_time': {}, 'tuning_task': {'start_time': {}, 'complete_time': {}, 'snapshots': [{'step': 444, 'epoch': 527, 'mean_loss': 0.961, 'compute_time': {}}], 'training_data': {'examples': {'examples': [{'text_input': 'text_input_value', 'output': 'output_value'}]}}, 'hyperparameters': {'epoch_count': 1175, 'batch_size': 1052, 'learning_rate': 0.1371}}} + request_init["tuned_model"] = { + "tuned_model_source": { + "tuned_model": "tuned_model_value", + "base_model": "base_model_value", + }, + "base_model": "base_model_value", + "name": "name_value", + "display_name": "display_name_value", + "description": "description_value", + "temperature": 0.1198, + "top_p": 0.546, + "top_k": 541, + "state": 1, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "tuning_task": { + "start_time": {}, + "complete_time": {}, + "snapshots": [ + {"step": 444, "epoch": 527, "mean_loss": 0.961, "compute_time": {}} + ], + "training_data": { + "examples": { + "examples": [ + {"text_input": "text_input_value", "output": "output_value"} + ] + } + }, + "hyperparameters": { + "epoch_count": 1175, + "batch_size": 1052, + "learning_rate": 0.1371, + }, + }, + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -3578,17 +4048,21 @@ def test_create_tuned_model_rest_flattened(): ) # Mock the http request call within the method and fake a response. 
- with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = operations_pb2.Operation(name='operations/spam') + return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method sample_request = {} # get truthy value for each flattened field mock_args = dict( - tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), - tuned_model_id='tuned_model_id_value', + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + tuned_model_id="tuned_model_id_value", ) mock_args.update(sample_request) @@ -3596,7 +4070,7 @@ def test_create_tuned_model_rest_flattened(): response_value = Response() response_value.status_code = 200 json_return_value = json_format.MessageToJson(return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.create_tuned_model(**mock_args) @@ -3605,10 +4079,12 @@ def test_create_tuned_model_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/tunedModels" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/tunedModels" % client.transport._host, args[1] + ) -def test_create_tuned_model_rest_flattened_error(transport: str = 'rest'): +def test_create_tuned_model_rest_flattened_error(transport: str = "rest"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3619,22 +4095,28 @@ def test_create_tuned_model_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.create_tuned_model( model_service.CreateTunedModelRequest(), - tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), - tuned_model_id='tuned_model_id_value', + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + tuned_model_id="tuned_model_id_value", ) def test_create_tuned_model_rest_error(): client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) -@pytest.mark.parametrize("request_type", [ - model_service.UpdateTunedModelRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + model_service.UpdateTunedModelRequest, + dict, + ], +) def test_update_tuned_model_rest(request_type): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -3642,22 +4124,56 @@ def test_update_tuned_model_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'tuned_model': {'name': 'tunedModels/sample1'}} - request_init["tuned_model"] = {'tuned_model_source': {'tuned_model': 'tuned_model_value', 'base_model': 'base_model_value'}, 'base_model': 'base_model_value', 'name': 'tunedModels/sample1', 'display_name': 'display_name_value', 'description': 'description_value', 'temperature': 0.1198, 'top_p': 0.546, 'top_k': 541, 'state': 1, 'create_time': {'seconds': 751, 'nanos': 543}, 'update_time': {}, 'tuning_task': 
{'start_time': {}, 'complete_time': {}, 'snapshots': [{'step': 444, 'epoch': 527, 'mean_loss': 0.961, 'compute_time': {}}], 'training_data': {'examples': {'examples': [{'text_input': 'text_input_value', 'output': 'output_value'}]}}, 'hyperparameters': {'epoch_count': 1175, 'batch_size': 1052, 'learning_rate': 0.1371}}} + request_init = {"tuned_model": {"name": "tunedModels/sample1"}} + request_init["tuned_model"] = { + "tuned_model_source": { + "tuned_model": "tuned_model_value", + "base_model": "base_model_value", + }, + "base_model": "base_model_value", + "name": "tunedModels/sample1", + "display_name": "display_name_value", + "description": "description_value", + "temperature": 0.1198, + "top_p": 0.546, + "top_k": 541, + "state": 1, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "tuning_task": { + "start_time": {}, + "complete_time": {}, + "snapshots": [ + {"step": 444, "epoch": 527, "mean_loss": 0.961, "compute_time": {}} + ], + "training_data": { + "examples": { + "examples": [ + {"text_input": "text_input_value", "output": "output_value"} + ] + } + }, + "hyperparameters": { + "epoch_count": 1175, + "batch_size": 1052, + "learning_rate": 0.1371, + }, + }, + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = gag_tuned_model.TunedModel( - name='name_value', - display_name='display_name_value', - description='description_value', - temperature=0.1198, - top_p=0.546, - top_k=541, - state=gag_tuned_model.TunedModel.State.CREATING, - base_model='base_model_value', + name="name_value", + display_name="display_name_value", + description="description_value", + temperature=0.1198, + top_p=0.546, + top_k=541, + state=gag_tuned_model.TunedModel.State.CREATING, + base_model="base_model_value", ) # Wrap the value into a proper Response obj @@ -3666,70 +4182,78 @@ def test_update_tuned_model_rest(request_type): pb_return_value = gag_tuned_model.TunedModel.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.update_tuned_model(request) # Establish that the response is the type that we expect. 
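# Note the float comparisons below: temperature and top_p go through JSON on
# the fake wire, so the tests use math.isclose rather than ==. For example:
import math

assert math.isclose(0.1198, 0.1198 + 1e-12, rel_tol=1e-6)  # survives round-trip noise
assert not math.isclose(0.1198, 0.1199, rel_tol=1e-6)      # a real difference still fails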
assert isinstance(response, gag_tuned_model.TunedModel) - assert response.name == 'name_value' - assert response.display_name == 'display_name_value' - assert response.description == 'description_value' + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) assert response.top_k == 541 assert response.state == gag_tuned_model.TunedModel.State.CREATING -def test_update_tuned_model_rest_required_fields(request_type=model_service.UpdateTunedModelRequest): +def test_update_tuned_model_rest_required_fields( + request_type=model_service.UpdateTunedModelRequest, +): transport_class = transports.ModelServiceRestTransport request_init = {} request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).update_tuned_model._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_tuned_model._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).update_tuned_model._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_tuned_model._get_unset_required_fields(jsonified_request) # Check that path parameters and body parameters are not mixing in. - assert not set(unset_fields) - set(("update_mask", )) + assert not set(unset_fields) - set(("update_mask",)) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = gag_tuned_model.TunedModel() # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. 
pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "patch", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, } - transcode_result['body'] = pb_request + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() @@ -3738,39 +4262,56 @@ def test_update_tuned_model_rest_required_fields(request_type=model_service.Upda pb_return_value = gag_tuned_model.TunedModel.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.update_tuned_model(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_update_tuned_model_rest_unset_required_fields(): - transport = transports.ModelServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.update_tuned_model._get_unset_required_fields({}) - assert set(unset_fields) == (set(("updateMask", )) & set(("tunedModel", "updateMask", ))) + assert set(unset_fields) == ( + set(("updateMask",)) + & set( + ( + "tunedModel", + "updateMask", + ) + ) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_update_tuned_model_rest_interceptors(null_interceptor): transport = transports.ModelServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) client = ModelServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "post_update_tuned_model") as post, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "pre_update_tuned_model") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ModelServiceRestInterceptor, "post_update_tuned_model" + ) as post, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_update_tuned_model" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = model_service.UpdateTunedModelRequest.pb(model_service.UpdateTunedModelRequest()) + pb_message = model_service.UpdateTunedModelRequest.pb( + model_service.UpdateTunedModelRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -3781,35 +4322,81 @@ def test_update_tuned_model_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = gag_tuned_model.TunedModel.to_json(gag_tuned_model.TunedModel()) + req.return_value._content = gag_tuned_model.TunedModel.to_json( + gag_tuned_model.TunedModel() + ) request = model_service.UpdateTunedModelRequest() - metadata =[ + metadata = [ ("key", "val"), 
("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = gag_tuned_model.TunedModel() - client.update_tuned_model(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.update_tuned_model( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_update_tuned_model_rest_bad_request(transport: str = 'rest', request_type=model_service.UpdateTunedModelRequest): +def test_update_tuned_model_rest_bad_request( + transport: str = "rest", request_type=model_service.UpdateTunedModelRequest +): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'tuned_model': {'name': 'tunedModels/sample1'}} - request_init["tuned_model"] = {'tuned_model_source': {'tuned_model': 'tuned_model_value', 'base_model': 'base_model_value'}, 'base_model': 'base_model_value', 'name': 'tunedModels/sample1', 'display_name': 'display_name_value', 'description': 'description_value', 'temperature': 0.1198, 'top_p': 0.546, 'top_k': 541, 'state': 1, 'create_time': {'seconds': 751, 'nanos': 543}, 'update_time': {}, 'tuning_task': {'start_time': {}, 'complete_time': {}, 'snapshots': [{'step': 444, 'epoch': 527, 'mean_loss': 0.961, 'compute_time': {}}], 'training_data': {'examples': {'examples': [{'text_input': 'text_input_value', 'output': 'output_value'}]}}, 'hyperparameters': {'epoch_count': 1175, 'batch_size': 1052, 'learning_rate': 0.1371}}} + request_init = {"tuned_model": {"name": "tunedModels/sample1"}} + request_init["tuned_model"] = { + "tuned_model_source": { + "tuned_model": "tuned_model_value", + "base_model": "base_model_value", + }, + "base_model": "base_model_value", + "name": "tunedModels/sample1", + "display_name": "display_name_value", + "description": "description_value", + "temperature": 0.1198, + "top_p": 0.546, + "top_k": 541, + "state": 1, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "tuning_task": { + "start_time": {}, + "complete_time": {}, + "snapshots": [ + {"step": 444, "epoch": 527, "mean_loss": 0.961, "compute_time": {}} + ], + "training_data": { + "examples": { + "examples": [ + {"text_input": "text_input_value", "output": "output_value"} + ] + } + }, + "hyperparameters": { + "epoch_count": 1175, + "batch_size": 1052, + "learning_rate": 0.1371, + }, + }, + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -3825,17 +4412,21 @@ def test_update_tuned_model_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
return_value = gag_tuned_model.TunedModel() # get arguments that satisfy an http rule for this method - sample_request = {'tuned_model': {'name': 'tunedModels/sample1'}} + sample_request = {"tuned_model": {"name": "tunedModels/sample1"}} # get truthy value for each flattened field mock_args = dict( - tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), - update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) mock_args.update(sample_request) @@ -3844,7 +4435,7 @@ def test_update_tuned_model_rest_flattened(): response_value.status_code = 200 pb_return_value = gag_tuned_model.TunedModel.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.update_tuned_model(**mock_args) @@ -3853,10 +4444,13 @@ def test_update_tuned_model_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{tuned_model.name=tunedModels/*}" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{tuned_model.name=tunedModels/*}" % client.transport._host, + args[1], + ) -def test_update_tuned_model_rest_flattened_error(transport: str = 'rest'): +def test_update_tuned_model_rest_flattened_error(transport: str = "rest"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3867,22 +4461,28 @@ def test_update_tuned_model_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.update_tuned_model( model_service.UpdateTunedModelRequest(), - tuned_model=gag_tuned_model.TunedModel(tuned_model_source=gag_tuned_model.TunedModelSource(tuned_model='tuned_model_value')), - update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) def test_update_tuned_model_rest_error(): client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) -@pytest.mark.parametrize("request_type", [ - model_service.DeleteTunedModelRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + model_service.DeleteTunedModelRequest, + dict, + ], +) def test_delete_tuned_model_rest(request_type): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -3890,20 +4490,20 @@ def test_delete_tuned_model_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'name': 'tunedModels/sample1'} + request_init = {"name": "tunedModels/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
return_value = None # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - json_return_value = '' + json_return_value = "" - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.delete_tuned_model(request) @@ -3911,94 +4511,110 @@ def test_delete_tuned_model_rest(request_type): assert response is None -def test_delete_tuned_model_rest_required_fields(request_type=model_service.DeleteTunedModelRequest): +def test_delete_tuned_model_rest_required_fields( + request_type=model_service.DeleteTunedModelRequest, +): transport_class = transports.ModelServiceRestTransport request_init = {} request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).delete_tuned_model._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_tuned_model._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["name"] = 'name_value' + jsonified_request["name"] = "name_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).delete_tuned_model._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_tuned_model._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "name" in jsonified_request - assert jsonified_request["name"] == 'name_value' + assert jsonified_request["name"] == "name_value" client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = None # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. 
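# transcode() maps a request message plus its http_options onto a concrete
# HTTP plan, and these tests stub it with the same dict shape the helper
# returns. A hand-built example of that shape (values illustrative; DELETE
# calls carry no body, so unbound fields land in query_params):
plan = {
    "uri": "v1/sample_method",
    "method": "delete",
    "query_params": {},
}
assert set(plan) == {"uri", "method", "query_params"}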
pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "delete", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, } transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 - json_return_value = '' + json_return_value = "" - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.delete_tuned_model(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_delete_tuned_model_rest_unset_required_fields(): - transport = transports.ModelServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.delete_tuned_model._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("name", ))) + assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_delete_tuned_model_rest_interceptors(null_interceptor): transport = transports.ModelServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.ModelServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) client = ModelServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.ModelServiceRestInterceptor, "pre_delete_tuned_model") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_delete_tuned_model" + ) as pre: pre.assert_not_called() - pb_message = model_service.DeleteTunedModelRequest.pb(model_service.DeleteTunedModelRequest()) + pb_message = model_service.DeleteTunedModelRequest.pb( + model_service.DeleteTunedModelRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -4011,29 +4627,39 @@ def test_delete_tuned_model_rest_interceptors(null_interceptor): req.return_value.request = PreparedRequest() request = model_service.DeleteTunedModelRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - client.delete_tuned_model(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.delete_tuned_model( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() -def test_delete_tuned_model_rest_bad_request(transport: str = 'rest', request_type=model_service.DeleteTunedModelRequest): +def test_delete_tuned_model_rest_bad_request( + transport: str = "rest", request_type=model_service.DeleteTunedModelRequest +): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'name': 'tunedModels/sample1'} + request_init = {"name": "tunedModels/sample1"} request = 
request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -4049,24 +4675,24 @@ def test_delete_tuned_model_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = None # get arguments that satisfy an http rule for this method - sample_request = {'name': 'tunedModels/sample1'} + sample_request = {"name": "tunedModels/sample1"} # get truthy value for each flattened field mock_args = dict( - name='name_value', + name="name_value", ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - json_return_value = '' - response_value._content = json_return_value.encode('UTF-8') + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.delete_tuned_model(**mock_args) @@ -4075,10 +4701,12 @@ def test_delete_tuned_model_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{name=tunedModels/*}" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{name=tunedModels/*}" % client.transport._host, args[1] + ) -def test_delete_tuned_model_rest_flattened_error(transport: str = 'rest'): +def test_delete_tuned_model_rest_flattened_error(transport: str = "rest"): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -4089,14 +4717,13 @@ def test_delete_tuned_model_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.delete_tuned_model( model_service.DeleteTunedModelRequest(), - name='name_value', + name="name_value", ) def test_delete_tuned_model_rest_error(): client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -4138,8 +4765,7 @@ def test_credentials_transport_error(): options.api_key = "api_key" with pytest.raises(ValueError): client = ModelServiceClient( - client_options=options, - credentials=ga_credentials.AnonymousCredentials() + client_options=options, credentials=ga_credentials.AnonymousCredentials() ) # It is an error to provide scopes and a transport instance. @@ -4161,6 +4787,7 @@ def test_transport_instance(): client = ModelServiceClient(transport=transport) assert client.transport is transport + def test_transport_get_channel(): # A client may be instantiated with a custom transport instance. 
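For the bad-request case above, the 400-status-to-exception mapping comes from google.api_core rather than from the generated client; a tiny check of that mapping, runnable on its own:

    from google.api_core import exceptions as core_exceptions

    # google.api_core maps HTTP status codes to typed exceptions; 400
    # becomes BadRequest, which is exactly what the bad-request test
    # expects the client call to raise.
    exc = core_exceptions.from_http_status(400, "boom")
    assert isinstance(exc, core_exceptions.BadRequest)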
transport = transports.ModelServiceGrpcTransport( @@ -4175,28 +4802,37 @@ def test_transport_get_channel(): channel = transport.grpc_channel assert channel -@pytest.mark.parametrize("transport_class", [ - transports.ModelServiceGrpcTransport, - transports.ModelServiceGrpcAsyncIOTransport, - transports.ModelServiceRestTransport, -]) + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceGrpcAsyncIOTransport, + transports.ModelServiceRestTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(google.auth, 'default') as adc: + with mock.patch.object(google.auth, "default") as adc: adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() -@pytest.mark.parametrize("transport_name", [ - "grpc", - "rest", -]) + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "rest", + ], +) def test_transport_kind(transport_name): transport = ModelServiceClient.get_transport_class(transport_name)( credentials=ga_credentials.AnonymousCredentials(), ) assert transport.kind == transport_name + def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = ModelServiceClient( @@ -4207,18 +4843,21 @@ def test_transport_grpc_default(): transports.ModelServiceGrpcTransport, ) + def test_model_service_base_transport_error(): # Passing both a credentials object and credentials_file should raise an error with pytest.raises(core_exceptions.DuplicateCredentialArgs): transport = transports.ModelServiceTransport( credentials=ga_credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_model_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.ai.generativelanguage_v1beta3.services.model_service.transports.ModelServiceTransport.__init__') as Transport: + with mock.patch( + "google.ai.generativelanguage_v1beta3.services.model_service.transports.ModelServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.ModelServiceTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -4227,13 +4866,13 @@ def test_model_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. 
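The base-transport test that follows relies on every RPC stub raising NotImplementedError until a concrete transport overrides it; a toy model of that contract (ToyBaseTransport is illustrative, not the generated class):

    import pytest

    class ToyBaseTransport:
        # The generated ModelServiceTransport base class stubs each RPC
        # this way; the gRPC and REST transports must override them all.
        def delete_tuned_model(self, request=None):
            raise NotImplementedError()

    methods = ("delete_tuned_model",)
    for method in methods:
        with pytest.raises(NotImplementedError):
            getattr(ToyBaseTransport(), method)(request=object())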
methods = ( - 'get_model', - 'list_models', - 'get_tuned_model', - 'list_tuned_models', - 'create_tuned_model', - 'update_tuned_model', - 'delete_tuned_model', + "get_model", + "list_models", + "get_tuned_model", + "list_tuned_models", + "create_tuned_model", + "update_tuned_model", + "delete_tuned_model", ) for method in methods: with pytest.raises(NotImplementedError): @@ -4249,7 +4888,7 @@ def test_model_service_base_transport(): # Catch all for all remaining methods and properties remainder = [ - 'kind', + "kind", ] for r in remainder: with pytest.raises(NotImplementedError): @@ -4258,24 +4897,30 @@ def test_model_service_base_transport(): def test_model_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta3.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.ai.generativelanguage_v1beta3.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) transport = transports.ModelServiceTransport( credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", + load_creds.assert_called_once_with( + "credentials.json", scopes=None, - default_scopes=( -), + default_scopes=(), quota_project_id="octopus", ) def test_model_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta3.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.ai.generativelanguage_v1beta3.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport = transports.ModelServiceTransport() @@ -4284,13 +4929,12 @@ def test_model_service_base_transport_with_adc(): def test_model_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(google.auth, 'default', autospec=True) as adc: + with mock.patch.object(google.auth, "default", autospec=True) as adc: adc.return_value = (ga_credentials.AnonymousCredentials(), None) ModelServiceClient() adc.assert_called_once_with( scopes=None, - default_scopes=( -), + default_scopes=(), quota_project_id=None, ) @@ -4305,7 +4949,7 @@ def test_model_service_auth_adc(): def test_model_service_transport_auth_adc(transport_class): # If credentials and host are not provided, the transport class should use # ADC credentials. 
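The ADC tests below all hinge on patching google.auth.default, which returns a (credentials, project_id) tuple; a self-contained sketch of that technique:

    from unittest import mock

    import google.auth
    from google.auth import credentials as ga_credentials

    # Patching google.auth.default is how these tests simulate Application
    # Default Credentials without touching the environment.
    with mock.patch.object(google.auth, "default", autospec=True) as adc:
        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
        creds, project = google.auth.default()
        assert isinstance(creds, ga_credentials.AnonymousCredentials)
        assert project is None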
- with mock.patch.object(google.auth, 'default', autospec=True) as adc: + with mock.patch.object(google.auth, "default", autospec=True) as adc: adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport_class(quota_project_id="octopus", scopes=["1", "2"]) adc.assert_called_once_with( @@ -4324,47 +4968,45 @@ def test_model_service_transport_auth_adc(transport_class): ], ) def test_model_service_transport_auth_gdch_credentials(transport_class): - host = 'https://language.com' - api_audience_tests = [None, 'https://language2.com'] - api_audience_expect = [host, 'https://language2.com'] + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] for t, e in zip(api_audience_tests, api_audience_expect): - with mock.patch.object(google.auth, 'default', autospec=True) as adc: + with mock.patch.object(google.auth, "default", autospec=True) as adc: gdch_mock = mock.MagicMock() - type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock) + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) adc.return_value = (gdch_mock, None) transport_class(host=host, api_audience=t) - gdch_mock.with_gdch_audience.assert_called_once_with( - e - ) + gdch_mock.with_gdch_audience.assert_called_once_with(e) @pytest.mark.parametrize( "transport_class,grpc_helpers", [ (transports.ModelServiceGrpcTransport, grpc_helpers), - (transports.ModelServiceGrpcAsyncIOTransport, grpc_helpers_async) + (transports.ModelServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) def test_model_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object( + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( grpc_helpers, "create_channel", autospec=True ) as create_channel: creds = ga_credentials.AnonymousCredentials() adc.return_value = (creds, None) - transport_class( - quota_project_id="octopus", - scopes=["1", "2"] - ) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) create_channel.assert_called_with( "generativelanguage.googleapis.com:443", credentials=creds, credentials_file=None, quota_project_id="octopus", - default_scopes=( -), + default_scopes=(), scopes=["1", "2"], default_host="generativelanguage.googleapis.com", ssl_credentials=None, @@ -4375,10 +5017,11 @@ def test_model_service_transport_create_channel(transport_class, grpc_helpers): ) -@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) -def test_model_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = ga_credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. 
@@ -4387,7 +5030,7 @@ def test_model_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", @@ -4408,20 +5051,21 @@ def test_model_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) + def test_model_service_http_transport_client_cert_source_for_mtls(): cred = ga_credentials.AnonymousCredentials() - with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel: - transports.ModelServiceRestTransport ( - credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.ModelServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback ) mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) @@ -4429,7 +5073,7 @@ def test_model_service_http_transport_client_cert_source_for_mtls(): def test_model_service_rest_lro_client(): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) transport = client.transport @@ -4443,43 +5087,58 @@ def test_model_service_rest_lro_client(): assert transport.operations_client is transport.operations_client -@pytest.mark.parametrize("transport_name", [ - "grpc", - "grpc_asyncio", - "rest", -]) +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) def test_model_service_host_no_port(transport_name): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'), - transport=transport_name, + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com" + ), + transport=transport_name, ) assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else 'https://generativelanguage.googleapis.com' + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" ) -@pytest.mark.parametrize("transport_name", [ - "grpc", - "grpc_asyncio", - "rest", -]) + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) def test_model_service_host_with_port(transport_name): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com:8000" + ), transport=transport_name, ) assert client.transport._host == ( - 'generativelanguage.googleapis.com:8000' - if transport_name in ['grpc', 'grpc_asyncio'] - else 
'https://generativelanguage.googleapis.com:8000' + "generativelanguage.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com:8000" ) -@pytest.mark.parametrize("transport_name", [ - "rest", -]) + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) def test_model_service_client_transport_session_collision(transport_name): creds1 = ga_credentials.AnonymousCredentials() creds2 = ga_credentials.AnonymousCredentials() @@ -4512,8 +5171,10 @@ def test_model_service_client_transport_session_collision(transport_name): session1 = client1.transport.delete_tuned_model._session session2 = client2.transport.delete_tuned_model._session assert session1 != session2 + + def test_model_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.ModelServiceGrpcTransport( @@ -4526,7 +5187,7 @@ def test_model_service_grpc_transport_channel(): def test_model_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.ModelServiceGrpcAsyncIOTransport( @@ -4540,12 +5201,17 @@ def test_model_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) -def test_model_service_transport_channel_mtls_with_client_cert_source( - transport_class -): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -4554,7 +5220,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source( cred = ga_credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(google.auth, 'default') as adc: + with mock.patch.object(google.auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -4584,17 +5250,20 @@ def test_model_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. 
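A rough model of the host expectation the two tests above assert (expected_host is an illustrative helper, not library code):

    def expected_host(api_endpoint: str, transport_name: str) -> str:
        # gRPC transports keep a bare host:port (defaulting to :443);
        # the REST transport prefixes the https:// scheme instead.
        if transport_name in ("grpc", "grpc_asyncio"):
            return api_endpoint if ":" in api_endpoint else api_endpoint + ":443"
        return "https://" + api_endpoint

    assert expected_host("generativelanguage.googleapis.com", "grpc") == (
        "generativelanguage.googleapis.com:443"
    )
    assert expected_host("generativelanguage.googleapis.com:8000", "rest") == (
        "https://generativelanguage.googleapis.com:8000"
    )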
-@pytest.mark.parametrize("transport_class", [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport]) -def test_model_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -4625,7 +5294,7 @@ def test_model_service_transport_channel_mtls_with_adc( def test_model_service_grpc_lro_client(): client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) transport = client.transport @@ -4642,7 +5311,7 @@ def test_model_service_grpc_lro_client(): def test_model_service_grpc_lro_async_client(): client = ModelServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc_asyncio', + transport="grpc_asyncio", ) transport = client.transport @@ -4658,7 +5327,9 @@ def test_model_service_grpc_lro_async_client(): def test_model_path(): model = "squid" - expected = "models/{model}".format(model=model, ) + expected = "models/{model}".format( + model=model, + ) actual = ModelServiceClient.model_path(model) assert expected == actual @@ -4673,9 +5344,12 @@ def test_parse_model_path(): actual = ModelServiceClient.parse_model_path(path) assert expected == actual + def test_tuned_model_path(): tuned_model = "whelk" - expected = "tunedModels/{tuned_model}".format(tuned_model=tuned_model, ) + expected = "tunedModels/{tuned_model}".format( + tuned_model=tuned_model, + ) actual = ModelServiceClient.tuned_model_path(tuned_model) assert expected == actual @@ -4690,9 +5364,12 @@ def test_parse_tuned_model_path(): actual = ModelServiceClient.parse_tuned_model_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "oyster" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = ModelServiceClient.common_billing_account_path(billing_account) assert expected == actual @@ -4707,9 +5384,12 @@ def test_parse_common_billing_account_path(): actual = ModelServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "cuttlefish" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format( + folder=folder, + ) actual = ModelServiceClient.common_folder_path(folder) assert expected == actual @@ -4724,9 +5404,12 @@ def test_parse_common_folder_path(): actual = ModelServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "winkle" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format( + organization=organization, + ) actual = ModelServiceClient.common_organization_path(organization) assert expected == actual @@ 
-4741,9 +5424,12 @@ def test_parse_common_organization_path(): actual = ModelServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "scallop" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format( + project=project, + ) actual = ModelServiceClient.common_project_path(project) assert expected == actual @@ -4758,10 +5444,14 @@ def test_parse_common_project_path(): actual = ModelServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "squid" location = "clam" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) actual = ModelServiceClient.common_location_path(project, location) assert expected == actual @@ -4781,14 +5471,18 @@ def test_parse_common_location_path(): def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.ModelServiceTransport, "_prep_wrapped_messages" + ) as prep: client = ModelServiceClient( credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.ModelServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.ModelServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = ModelServiceClient.get_transport_class() transport = transport_class( credentials=ga_credentials.AnonymousCredentials(), @@ -4796,13 +5490,16 @@ def test_client_with_default_client_info(): ) prep.assert_called_once_with(client_info) + @pytest.mark.asyncio async def test_transport_close_async(): client = ModelServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", ) - with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close: + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: async with client: close.assert_not_called() close.assert_called_once() @@ -4816,23 +5513,24 @@ def test_transport_close(): for transport, close_name in transports.items(): client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport + credentials=ga_credentials.AnonymousCredentials(), transport=transport ) - with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close: + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: with client: close.assert_not_called() close.assert_called_once() + def test_client_ctx(): transports = [ - 'rest', - 'grpc', + "rest", + "grpc", ] for transport in transports: client = ModelServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport + credentials=ga_credentials.AnonymousCredentials(), transport=transport ) # Test client calls underlying transport. 
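The *_path/parse_* helpers exercised in these hunks are simple format/regex pairs; a minimal model of one of them (illustrative, not the generated code):

    import re

    def common_location_path(project: str, location: str) -> str:
        # Forward direction: plain str.format over the resource template.
        return "projects/{project}/locations/{location}".format(
            project=project, location=location
        )

    def parse_common_location_path(path: str) -> dict:
        # Reverse direction: a named-group regex recovers the segments.
        m = re.match(
            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path
        )
        return m.groupdict() if m else {}

    assert parse_common_location_path(common_location_path("squid", "clam")) == {
        "project": "squid",
        "location": "clam",
    }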
with mock.patch.object(type(client.transport), "close") as close: @@ -4841,10 +5539,14 @@ def test_client_ctx(): pass close.assert_called() -@pytest.mark.parametrize("client_class,transport_class", [ - (ModelServiceClient, transports.ModelServiceGrpcTransport), - (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport), -]) + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport), + ], +) def test_api_key_credentials(client_class, transport_class): with mock.patch.object( google.auth._default, "get_api_key_credentials", create=True diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_permission_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/test_permission_service.py similarity index 69% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_permission_service.py rename to packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/test_permission_service.py index aa9954df98a8..108f7e7829d7 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_permission_service.py +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/test_permission_service.py @@ -14,6 +14,7 @@ # limitations under the License. # import os + # try/except added for compatibility with python < 3.8 try: from unittest import mock @@ -21,39 +22,37 @@ except ImportError: # pragma: NO COVER import mock -import grpc -from grpc.experimental import aio from collections.abc import Iterable -from google.protobuf import json_format import json import math -import pytest -from proto.marshal.rules.dates import DurationRule, TimestampRule -from proto.marshal.rules import wrappers -from requests import Response -from requests import Request, PreparedRequest -from requests.sessions import Session -from google.protobuf import json_format -from google.ai.generativelanguage_v1beta3.services.permission_service import PermissionServiceAsyncClient -from google.ai.generativelanguage_v1beta3.services.permission_service import PermissionServiceClient -from google.ai.generativelanguage_v1beta3.services.permission_service import pagers -from google.ai.generativelanguage_v1beta3.services.permission_service import transports -from google.ai.generativelanguage_v1beta3.types import permission -from google.ai.generativelanguage_v1beta3.types import permission as gag_permission -from google.ai.generativelanguage_v1beta3.types import permission_service +from google.api_core import gapic_v1, grpc_helpers, grpc_helpers_async, path_template from google.api_core import client_options from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import path_template +import google.auth from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError -from google.longrunning import operations_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore from google.oauth2 import service_account from google.protobuf import field_mask_pb2 # type: ignore -import google.auth +from google.protobuf import json_format +import 
grpc +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +import pytest +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + +from google.ai.generativelanguage_v1beta3.services.permission_service import ( + PermissionServiceAsyncClient, + PermissionServiceClient, + pagers, + transports, +) +from google.ai.generativelanguage_v1beta3.types import permission as gag_permission +from google.ai.generativelanguage_v1beta3.types import permission +from google.ai.generativelanguage_v1beta3.types import permission_service def client_cert_source_callback(): @@ -64,7 +63,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -75,21 +78,43 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert PermissionServiceClient._get_default_mtls_endpoint(None) is None - assert PermissionServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert PermissionServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert PermissionServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert PermissionServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint - assert PermissionServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi - - -@pytest.mark.parametrize("client_class,transport_name", [ - (PermissionServiceClient, "grpc"), - (PermissionServiceAsyncClient, "grpc_asyncio"), - (PermissionServiceClient, "rest"), -]) -def test_permission_service_client_from_service_account_info(client_class, transport_name): + assert ( + PermissionServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + PermissionServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + PermissionServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PermissionServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PermissionServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (PermissionServiceClient, "grpc"), + (PermissionServiceAsyncClient, "grpc_asyncio"), + (PermissionServiceClient, "rest"), + ], +) +def test_permission_service_client_from_service_account_info( + client_class, transport_name +): creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info, transport=transport_name) @@ -97,52 +122,70 @@ def test_permission_service_client_from_service_account_info(client_class, trans assert isinstance(client, client_class) assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - 
if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" ) -@pytest.mark.parametrize("transport_class,transport_name", [ - (transports.PermissionServiceGrpcTransport, "grpc"), - (transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (transports.PermissionServiceRestTransport, "rest"), -]) -def test_permission_service_client_service_account_always_use_jwt(transport_class, transport_name): - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.PermissionServiceGrpcTransport, "grpc"), + (transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.PermissionServiceRestTransport, "rest"), + ], +) +def test_permission_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: creds = service_account.Credentials(None, None, None) transport = transport_class(credentials=creds, always_use_jwt_access=True) use_jwt.assert_called_once_with(True) - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: creds = service_account.Credentials(None, None, None) transport = transport_class(credentials=creds, always_use_jwt_access=False) use_jwt.assert_not_called() -@pytest.mark.parametrize("client_class,transport_name", [ - (PermissionServiceClient, "grpc"), - (PermissionServiceAsyncClient, "grpc_asyncio"), - (PermissionServiceClient, "rest"), -]) -def test_permission_service_client_from_service_account_file(client_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (PermissionServiceClient, "grpc"), + (PermissionServiceAsyncClient, "grpc_asyncio"), + (PermissionServiceClient, "rest"), + ], +) +def test_permission_service_client_from_service_account_file( + client_class, transport_name +): creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name) + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) assert client.transport._credentials == creds assert isinstance(client, client_class) - client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name) + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) assert client.transport._credentials == creds assert isinstance(client, client_class) assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" ) @@ -158,30 +201,45 @@ def test_permission_service_client_get_transport_class(): 
assert transport == transports.PermissionServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc"), - (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest"), -]) -@mock.patch.object(PermissionServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PermissionServiceClient)) -@mock.patch.object(PermissionServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PermissionServiceAsyncClient)) -def test_permission_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc"), + ( + PermissionServiceAsyncClient, + transports.PermissionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + PermissionServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PermissionServiceClient), +) +@mock.patch.object( + PermissionServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PermissionServiceAsyncClient), +) +def test_permission_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(PermissionServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=ga_credentials.AnonymousCredentials() - ) + with mock.patch.object(PermissionServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(PermissionServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(PermissionServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(transport=transport_name, client_options=options) patched.assert_called_once_with( @@ -199,7 +257,7 @@ def test_permission_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(transport=transport_name) patched.assert_called_once_with( @@ -217,7 +275,7 @@ def test_permission_service_client_client_options(client_class, transport_class, # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". 
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(transport=transport_name) patched.assert_called_once_with( @@ -239,13 +297,15 @@ def test_permission_service_client_client_options(client_class, transport_class, client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -260,8 +320,10 @@ def test_permission_service_client_client_options(client_class, transport_class, api_audience=None, ) # Check the case api_endpoint is provided - options = client_options.ClientOptions(api_audience="https://language.googleapis.com") - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -273,29 +335,77 @@ def test_permission_service_client_client_options(client_class, transport_class, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, always_use_jwt_access=True, - api_audience="https://language.googleapis.com" - ) - -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc", "true"), - (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc", "false"), - (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), - (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest", "true"), - (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest", "false"), -]) -@mock.patch.object(PermissionServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PermissionServiceClient)) -@mock.patch.object(PermissionServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PermissionServiceAsyncClient)) + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + PermissionServiceClient, + transports.PermissionServiceGrpcTransport, + "grpc", + "true", + ), + ( + PermissionServiceAsyncClient, + transports.PermissionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + PermissionServiceClient, + transports.PermissionServiceGrpcTransport, + "grpc", + "false", + ), + ( + PermissionServiceAsyncClient, + transports.PermissionServiceGrpcAsyncIOTransport, + 
"grpc_asyncio", + "false", + ), + ( + PermissionServiceClient, + transports.PermissionServiceRestTransport, + "rest", + "true", + ), + ( + PermissionServiceClient, + transports.PermissionServiceRestTransport, + "rest", + "false", + ), + ], +) +@mock.patch.object( + PermissionServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PermissionServiceClient), +) +@mock.patch.object( + PermissionServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PermissionServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_permission_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_permission_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) @@ -320,10 +430,18 @@ def test_permission_service_client_mtls_env_auto(client_class, transport_class, # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -346,9 +464,14 @@ def test_permission_service_client_mtls_env_auto(client_class, transport_class, ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class(transport=transport_name) patched.assert_called_once_with( @@ -364,19 +487,31 @@ def test_permission_service_client_mtls_env_auto(client_class, transport_class, ) -@pytest.mark.parametrize("client_class", [ - PermissionServiceClient, PermissionServiceAsyncClient -]) -@mock.patch.object(PermissionServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PermissionServiceClient)) -@mock.patch.object(PermissionServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(PermissionServiceAsyncClient)) +@pytest.mark.parametrize( + "client_class", [PermissionServiceClient, PermissionServiceAsyncClient] +) +@mock.patch.object( + PermissionServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PermissionServiceClient), +) +@mock.patch.object( + PermissionServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PermissionServiceAsyncClient), +) def test_permission_service_client_get_mtls_endpoint_and_cert_source(client_class): mock_client_cert_source = mock.Mock() # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) assert api_endpoint == mock_api_endpoint assert cert_source == mock_client_cert_source @@ -384,8 +519,12 @@ def test_permission_service_client_get_mtls_endpoint_and_cert_source(client_clas with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): mock_client_cert_source = mock.Mock() mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) assert api_endpoint == mock_api_endpoint assert cert_source is None @@ -403,31 +542,52 @@ def test_permission_service_client_get_mtls_endpoint_and_cert_source(client_clas # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
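A simplified decision table for the endpoint selection these cases pin down (pick_endpoint is an illustrative name; the real logic lives in get_mtls_endpoint_and_cert_source and also consults GOOGLE_API_USE_CLIENT_CERTIFICATE and the ADC cert source):

    import os
    from unittest import mock

    def pick_endpoint(default: str, mtls_default: str, has_client_cert: bool) -> str:
        # GOOGLE_API_USE_MTLS_ENDPOINT: "always" forces the mTLS endpoint,
        # "never" forces the plain one, and "auto" (the default) switches
        # to mTLS only when a client certificate is actually available.
        mode = os.environ.get("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
        if mode == "always":
            return mtls_default
        if mode == "never":
            return default
        return mtls_default if has_client_cert else default

    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
        assert pick_endpoint(
            "api.example.com", "api.mtls.example.com", True
        ) == "api.example.com"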
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() assert api_endpoint == client_class.DEFAULT_ENDPOINT assert cert_source is None # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT assert cert_source == mock_client_cert_source -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc"), - (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest"), -]) -def test_permission_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc"), + ( + PermissionServiceAsyncClient, + transports.PermissionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest"), + ], +) +def test_permission_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. 
options = client_options.ClientOptions( scopes=["1", "2"], ) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -442,18 +602,37 @@ def test_permission_service_client_client_options_scopes(client_class, transport api_audience=None, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc", grpc_helpers), - (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), - (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest", None), -]) -def test_permission_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + PermissionServiceClient, + transports.PermissionServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + PermissionServiceAsyncClient, + transports.PermissionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ( + PermissionServiceClient, + transports.PermissionServiceRestTransport, + "rest", + None, + ), + ], +) +def test_permission_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) + options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -468,11 +647,14 @@ def test_permission_service_client_client_options_credentials_file(client_class, api_audience=None, ) + def test_permission_service_client_client_options_from_dict(): - with mock.patch('google.ai.generativelanguage_v1beta3.services.permission_service.transports.PermissionServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.ai.generativelanguage_v1beta3.services.permission_service.transports.PermissionServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None client = PermissionServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} + client_options={"api_endpoint": "squid.clam.whelk"} ) grpc_transport.assert_called_once_with( credentials=None, @@ -487,17 +669,30 @@ def test_permission_service_client_client_options_from_dict(): ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc", grpc_helpers), - (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), -]) -def test_permission_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + PermissionServiceClient, + transports.PermissionServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + PermissionServiceAsyncClient, + 
transports.PermissionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_permission_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) + options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -530,8 +725,7 @@ def test_permission_service_client_create_channel_credentials_file(client_class, credentials=file_creds, credentials_file=None, quota_project_id=None, - default_scopes=( -), + default_scopes=(), scopes=None, default_host="generativelanguage.googleapis.com", ssl_credentials=None, @@ -542,11 +736,14 @@ def test_permission_service_client_create_channel_credentials_file(client_class, ) -@pytest.mark.parametrize("request_type", [ - permission_service.CreatePermissionRequest, - dict, -]) -def test_create_permission(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + permission_service.CreatePermissionRequest, + dict, + ], +) +def test_create_permission(request_type, transport: str = "grpc"): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -558,13 +755,13 @@ def test_create_permission(request_type, transport: str = 'grpc'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_permission), - '__call__') as call: + type(client.transport.create_permission), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gag_permission.Permission( - name='name_value', + name="name_value", grantee_type=gag_permission.Permission.GranteeType.USER, - email_address='email_address_value', + email_address="email_address_value", role=gag_permission.Permission.Role.OWNER, ) response = client.create_permission(request) @@ -576,9 +773,9 @@ def test_create_permission(request_type, transport: str = 'grpc'): # Establish that the response is the type that we expect. assert isinstance(response, gag_permission.Permission) - assert response.name == 'name_value' + assert response.name == "name_value" assert response.grantee_type == gag_permission.Permission.GranteeType.USER - assert response.email_address == 'email_address_value' + assert response.email_address == "email_address_value" assert response.role == gag_permission.Permission.Role.OWNER @@ -587,20 +784,24 @@ def test_create_permission_empty_call(): # i.e. request == None and no flattened fields passed, work. client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.create_permission), - '__call__') as call: + type(client.transport.create_permission), "__call__" + ) as call: client.create_permission() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == permission_service.CreatePermissionRequest() + @pytest.mark.asyncio -async def test_create_permission_async(transport: str = 'grpc_asyncio', request_type=permission_service.CreatePermissionRequest): +async def test_create_permission_async( + transport: str = "grpc_asyncio", + request_type=permission_service.CreatePermissionRequest, +): client = PermissionServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -612,15 +813,17 @@ async def test_create_permission_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_permission), - '__call__') as call: + type(client.transport.create_permission), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(gag_permission.Permission( - name='name_value', - grantee_type=gag_permission.Permission.GranteeType.USER, - email_address='email_address_value', - role=gag_permission.Permission.Role.OWNER, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission( + name="name_value", + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=gag_permission.Permission.Role.OWNER, + ) + ) response = await client.create_permission(request) # Establish that the underlying gRPC stub method was called. @@ -630,9 +833,9 @@ async def test_create_permission_async(transport: str = 'grpc_asyncio', request_ # Establish that the response is the type that we expect. assert isinstance(response, gag_permission.Permission) - assert response.name == 'name_value' + assert response.name == "name_value" assert response.grantee_type == gag_permission.Permission.GranteeType.USER - assert response.email_address == 'email_address_value' + assert response.email_address == "email_address_value" assert response.role == gag_permission.Permission.Role.OWNER @@ -650,12 +853,12 @@ def test_create_permission_field_headers(): # a field header. Set these to a non-empty value. request = permission_service.CreatePermissionRequest() - request.parent = 'parent_value' + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_permission), - '__call__') as call: + type(client.transport.create_permission), "__call__" + ) as call: call.return_value = gag_permission.Permission() client.create_permission(request) @@ -667,9 +870,9 @@ def test_create_permission_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'parent=parent_value', - ) in kw['metadata'] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -682,13 +885,15 @@ async def test_create_permission_field_headers_async(): # a field header. Set these to a non-empty value. request = permission_service.CreatePermissionRequest() - request.parent = 'parent_value' + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. 
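    # With the async client the mocked stub has to return an awaitable, so the
    # canned Permission is wrapped in grpc_helpers_async.FakeUnaryUnaryCall
    # instead of being assigned directly as in the sync tests above.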
with mock.patch.object( - type(client.transport.create_permission), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_permission.Permission()) + type(client.transport.create_permission), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission() + ) await client.create_permission(request) # Establish that the underlying gRPC stub method was called. @@ -699,9 +904,9 @@ async def test_create_permission_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'parent=parent_value', - ) in kw['metadata'] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_create_permission_flattened(): @@ -711,15 +916,15 @@ def test_create_permission_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_permission), - '__call__') as call: + type(client.transport.create_permission), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gag_permission.Permission() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.create_permission( - parent='parent_value', - permission=gag_permission.Permission(name='name_value'), + parent="parent_value", + permission=gag_permission.Permission(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -727,10 +932,10 @@ def test_create_permission_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].parent - mock_val = 'parent_value' + mock_val = "parent_value" assert arg == mock_val arg = args[0].permission - mock_val = gag_permission.Permission(name='name_value') + mock_val = gag_permission.Permission(name="name_value") assert arg == mock_val @@ -744,10 +949,11 @@ def test_create_permission_flattened_error(): with pytest.raises(ValueError): client.create_permission( permission_service.CreatePermissionRequest(), - parent='parent_value', - permission=gag_permission.Permission(name='name_value'), + parent="parent_value", + permission=gag_permission.Permission(name="name_value"), ) + @pytest.mark.asyncio async def test_create_permission_flattened_async(): client = PermissionServiceAsyncClient( @@ -756,17 +962,19 @@ async def test_create_permission_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.create_permission), - '__call__') as call: + type(client.transport.create_permission), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gag_permission.Permission() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_permission.Permission()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.create_permission( - parent='parent_value', - permission=gag_permission.Permission(name='name_value'), + parent="parent_value", + permission=gag_permission.Permission(name="name_value"), ) # Establish that the underlying call was made with the expected @@ -774,12 +982,13 @@ async def test_create_permission_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].parent - mock_val = 'parent_value' + mock_val = "parent_value" assert arg == mock_val arg = args[0].permission - mock_val = gag_permission.Permission(name='name_value') + mock_val = gag_permission.Permission(name="name_value") assert arg == mock_val + @pytest.mark.asyncio async def test_create_permission_flattened_error_async(): client = PermissionServiceAsyncClient( @@ -791,16 +1000,19 @@ async def test_create_permission_flattened_error_async(): with pytest.raises(ValueError): await client.create_permission( permission_service.CreatePermissionRequest(), - parent='parent_value', - permission=gag_permission.Permission(name='name_value'), + parent="parent_value", + permission=gag_permission.Permission(name="name_value"), ) -@pytest.mark.parametrize("request_type", [ - permission_service.GetPermissionRequest, - dict, -]) -def test_get_permission(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + permission_service.GetPermissionRequest, + dict, + ], +) +def test_get_permission(request_type, transport: str = "grpc"): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -811,14 +1023,12 @@ def test_get_permission(request_type, transport: str = 'grpc'): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_permission), - '__call__') as call: + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = permission.Permission( - name='name_value', + name="name_value", grantee_type=permission.Permission.GranteeType.USER, - email_address='email_address_value', + email_address="email_address_value", role=permission.Permission.Role.OWNER, ) response = client.get_permission(request) @@ -830,9 +1040,9 @@ def test_get_permission(request_type, transport: str = 'grpc'): # Establish that the response is the type that we expect. assert isinstance(response, permission.Permission) - assert response.name == 'name_value' + assert response.name == "name_value" assert response.grantee_type == permission.Permission.GranteeType.USER - assert response.email_address == 'email_address_value' + assert response.email_address == "email_address_value" assert response.role == permission.Permission.Role.OWNER @@ -841,20 +1051,22 @@ def test_get_permission_empty_call(): # i.e. request == None and no flattened fields passed, work. client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.get_permission), - '__call__') as call: + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: client.get_permission() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == permission_service.GetPermissionRequest() + @pytest.mark.asyncio -async def test_get_permission_async(transport: str = 'grpc_asyncio', request_type=permission_service.GetPermissionRequest): +async def test_get_permission_async( + transport: str = "grpc_asyncio", + request_type=permission_service.GetPermissionRequest, +): client = PermissionServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -865,16 +1077,16 @@ async def test_get_permission_async(transport: str = 'grpc_asyncio', request_typ request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_permission), - '__call__') as call: + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(permission.Permission( - name='name_value', - grantee_type=permission.Permission.GranteeType.USER, - email_address='email_address_value', - role=permission.Permission.Role.OWNER, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission.Permission( + name="name_value", + grantee_type=permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=permission.Permission.Role.OWNER, + ) + ) response = await client.get_permission(request) # Establish that the underlying gRPC stub method was called. @@ -884,9 +1096,9 @@ async def test_get_permission_async(transport: str = 'grpc_asyncio', request_typ # Establish that the response is the type that we expect. assert isinstance(response, permission.Permission) - assert response.name == 'name_value' + assert response.name == "name_value" assert response.grantee_type == permission.Permission.GranteeType.USER - assert response.email_address == 'email_address_value' + assert response.email_address == "email_address_value" assert response.role == permission.Permission.Role.OWNER @@ -904,12 +1116,10 @@ def test_get_permission_field_headers(): # a field header. Set these to a non-empty value. request = permission_service.GetPermissionRequest() - request.name = 'name_value' + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_permission), - '__call__') as call: + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: call.return_value = permission.Permission() client.get_permission(request) @@ -921,9 +1131,9 @@ def test_get_permission_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -936,13 +1146,13 @@ async def test_get_permission_field_headers_async(): # a field header. Set these to a non-empty value. request = permission_service.GetPermissionRequest() - request.name = 'name_value' + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.get_permission), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(permission.Permission()) + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission.Permission() + ) await client.get_permission(request) # Establish that the underlying gRPC stub method was called. @@ -953,9 +1163,9 @@ async def test_get_permission_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] def test_get_permission_flattened(): @@ -964,15 +1174,13 @@ def test_get_permission_flattened(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_permission), - '__call__') as call: + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = permission.Permission() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.get_permission( - name='name_value', + name="name_value", ) # Establish that the underlying call was made with the expected @@ -980,7 +1188,7 @@ def test_get_permission_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].name - mock_val = 'name_value' + mock_val = "name_value" assert arg == mock_val @@ -994,9 +1202,10 @@ def test_get_permission_flattened_error(): with pytest.raises(ValueError): client.get_permission( permission_service.GetPermissionRequest(), - name='name_value', + name="name_value", ) + @pytest.mark.asyncio async def test_get_permission_flattened_async(): client = PermissionServiceAsyncClient( @@ -1004,17 +1213,17 @@ async def test_get_permission_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.get_permission), - '__call__') as call: + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = permission.Permission() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(permission.Permission()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission.Permission() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.get_permission( - name='name_value', + name="name_value", ) # Establish that the underlying call was made with the expected @@ -1022,9 +1231,10 @@ async def test_get_permission_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].name - mock_val = 'name_value' + mock_val = "name_value" assert arg == mock_val + @pytest.mark.asyncio async def test_get_permission_flattened_error_async(): client = PermissionServiceAsyncClient( @@ -1036,15 +1246,18 @@ async def test_get_permission_flattened_error_async(): with pytest.raises(ValueError): await client.get_permission( permission_service.GetPermissionRequest(), - name='name_value', + name="name_value", ) -@pytest.mark.parametrize("request_type", [ - permission_service.ListPermissionsRequest, - dict, -]) -def test_list_permissions(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + permission_service.ListPermissionsRequest, + dict, + ], +) +def test_list_permissions(request_type, transport: str = "grpc"): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1055,12 +1268,10 @@ def test_list_permissions(request_type, transport: str = 'grpc'): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_permissions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = permission_service.ListPermissionsResponse( - next_page_token='next_page_token_value', + next_page_token="next_page_token_value", ) response = client.list_permissions(request) @@ -1071,7 +1282,7 @@ def test_list_permissions(request_type, transport: str = 'grpc'): # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListPermissionsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" def test_list_permissions_empty_call(): @@ -1079,20 +1290,22 @@ def test_list_permissions_empty_call(): # i.e. request == None and no flattened fields passed, work. client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_permissions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: client.list_permissions() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == permission_service.ListPermissionsRequest() + @pytest.mark.asyncio -async def test_list_permissions_async(transport: str = 'grpc_asyncio', request_type=permission_service.ListPermissionsRequest): +async def test_list_permissions_async( + transport: str = "grpc_asyncio", + request_type=permission_service.ListPermissionsRequest, +): client = PermissionServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1103,13 +1316,13 @@ async def test_list_permissions_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.list_permissions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(permission_service.ListPermissionsResponse( - next_page_token='next_page_token_value', - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission_service.ListPermissionsResponse( + next_page_token="next_page_token_value", + ) + ) response = await client.list_permissions(request) # Establish that the underlying gRPC stub method was called. @@ -1119,7 +1332,7 @@ async def test_list_permissions_async(transport: str = 'grpc_asyncio', request_t # Establish that the response is the type that we expect. assert isinstance(response, pagers.ListPermissionsAsyncPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" @pytest.mark.asyncio @@ -1136,12 +1349,10 @@ def test_list_permissions_field_headers(): # a field header. Set these to a non-empty value. request = permission_service.ListPermissionsRequest() - request.parent = 'parent_value' + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_permissions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: call.return_value = permission_service.ListPermissionsResponse() client.list_permissions(request) @@ -1153,9 +1364,9 @@ def test_list_permissions_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'parent=parent_value', - ) in kw['metadata'] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1168,13 +1379,13 @@ async def test_list_permissions_field_headers_async(): # a field header. Set these to a non-empty value. request = permission_service.ListPermissionsRequest() - request.parent = 'parent_value' + request.parent = "parent_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_permissions), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(permission_service.ListPermissionsResponse()) + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission_service.ListPermissionsResponse() + ) await client.list_permissions(request) # Establish that the underlying gRPC stub method was called. @@ -1185,9 +1396,9 @@ async def test_list_permissions_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'parent=parent_value', - ) in kw['metadata'] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_list_permissions_flattened(): @@ -1196,15 +1407,13 @@ def test_list_permissions_flattened(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_permissions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: # Designate an appropriate return value for the call. 
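        # A bare ListPermissionsResponse is sufficient here: the flattened test
        # only checks that the keyword argument lands in the request proto, not
        # that paging behaves.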
call.return_value = permission_service.ListPermissionsResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.list_permissions( - parent='parent_value', + parent="parent_value", ) # Establish that the underlying call was made with the expected @@ -1212,7 +1421,7 @@ def test_list_permissions_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].parent - mock_val = 'parent_value' + mock_val = "parent_value" assert arg == mock_val @@ -1226,9 +1435,10 @@ def test_list_permissions_flattened_error(): with pytest.raises(ValueError): client.list_permissions( permission_service.ListPermissionsRequest(), - parent='parent_value', + parent="parent_value", ) + @pytest.mark.asyncio async def test_list_permissions_flattened_async(): client = PermissionServiceAsyncClient( @@ -1236,17 +1446,17 @@ async def test_list_permissions_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_permissions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = permission_service.ListPermissionsResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(permission_service.ListPermissionsResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission_service.ListPermissionsResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.list_permissions( - parent='parent_value', + parent="parent_value", ) # Establish that the underlying call was made with the expected @@ -1254,9 +1464,10 @@ async def test_list_permissions_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].parent - mock_val = 'parent_value' + mock_val = "parent_value" assert arg == mock_val + @pytest.mark.asyncio async def test_list_permissions_flattened_error_async(): client = PermissionServiceAsyncClient( @@ -1268,7 +1479,7 @@ async def test_list_permissions_flattened_error_async(): with pytest.raises(ValueError): await client.list_permissions( permission_service.ListPermissionsRequest(), - parent='parent_value', + parent="parent_value", ) @@ -1279,9 +1490,7 @@ def test_list_permissions_pager(transport_name: str = "grpc"): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_permissions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: # Set the response to a series of pages. 
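        # Four pages (3 + 0 + 1 + 2 permissions) are followed by RuntimeError as
        # a sentinel: a correct pager drains the page whose next_page_token is
        # empty and stops before the error is ever raised. A minimal consumer of
        # the same pager, assuming usable credentials, would look roughly like:
        #
        #     for p in client.list_permissions(parent="tunedModels/sample1"):
        #         print(p.name)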
call.side_effect = ( permission_service.ListPermissionsResponse( @@ -1290,17 +1499,17 @@ def test_list_permissions_pager(transport_name: str = "grpc"): permission.Permission(), permission.Permission(), ], - next_page_token='abc', + next_page_token="abc", ), permission_service.ListPermissionsResponse( permissions=[], - next_page_token='def', + next_page_token="def", ), permission_service.ListPermissionsResponse( permissions=[ permission.Permission(), ], - next_page_token='ghi', + next_page_token="ghi", ), permission_service.ListPermissionsResponse( permissions=[ @@ -1313,9 +1522,7 @@ def test_list_permissions_pager(transport_name: str = "grpc"): metadata = () metadata = tuple(metadata) + ( - gapic_v1.routing_header.to_grpc_metadata(( - ('parent', ''), - )), + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), ) pager = client.list_permissions(request={}) @@ -1323,8 +1530,9 @@ def test_list_permissions_pager(transport_name: str = "grpc"): results = list(pager) assert len(results) == 6 - assert all(isinstance(i, permission.Permission) - for i in results) + assert all(isinstance(i, permission.Permission) for i in results) + + def test_list_permissions_pages(transport_name: str = "grpc"): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials, @@ -1332,9 +1540,7 @@ def test_list_permissions_pages(transport_name: str = "grpc"): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.list_permissions), - '__call__') as call: + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: # Set the response to a series of pages. call.side_effect = ( permission_service.ListPermissionsResponse( @@ -1343,17 +1549,17 @@ def test_list_permissions_pages(transport_name: str = "grpc"): permission.Permission(), permission.Permission(), ], - next_page_token='abc', + next_page_token="abc", ), permission_service.ListPermissionsResponse( permissions=[], - next_page_token='def', + next_page_token="def", ), permission_service.ListPermissionsResponse( permissions=[ permission.Permission(), ], - next_page_token='ghi', + next_page_token="ghi", ), permission_service.ListPermissionsResponse( permissions=[ @@ -1364,9 +1570,10 @@ def test_list_permissions_pages(transport_name: str = "grpc"): RuntimeError, ) pages = list(client.list_permissions(request={}).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token + @pytest.mark.asyncio async def test_list_permissions_async_pager(): client = PermissionServiceAsyncClient( @@ -1375,8 +1582,8 @@ async def test_list_permissions_async_pager(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_permissions), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_permissions), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. 
call.side_effect = ( permission_service.ListPermissionsResponse( @@ -1385,17 +1592,17 @@ async def test_list_permissions_async_pager(): permission.Permission(), permission.Permission(), ], - next_page_token='abc', + next_page_token="abc", ), permission_service.ListPermissionsResponse( permissions=[], - next_page_token='def', + next_page_token="def", ), permission_service.ListPermissionsResponse( permissions=[ permission.Permission(), ], - next_page_token='ghi', + next_page_token="ghi", ), permission_service.ListPermissionsResponse( permissions=[ @@ -1405,15 +1612,16 @@ async def test_list_permissions_async_pager(): ), RuntimeError, ) - async_pager = await client.list_permissions(request={},) - assert async_pager.next_page_token == 'abc' + async_pager = await client.list_permissions( + request={}, + ) + assert async_pager.next_page_token == "abc" responses = [] - async for response in async_pager: # pragma: no branch + async for response in async_pager: # pragma: no branch responses.append(response) assert len(responses) == 6 - assert all(isinstance(i, permission.Permission) - for i in responses) + assert all(isinstance(i, permission.Permission) for i in responses) @pytest.mark.asyncio @@ -1424,8 +1632,8 @@ async def test_list_permissions_async_pages(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.list_permissions), - '__call__', new_callable=mock.AsyncMock) as call: + type(client.transport.list_permissions), "__call__", new_callable=mock.AsyncMock + ) as call: # Set the response to a series of pages. call.side_effect = ( permission_service.ListPermissionsResponse( @@ -1434,17 +1642,17 @@ async def test_list_permissions_async_pages(): permission.Permission(), permission.Permission(), ], - next_page_token='abc', + next_page_token="abc", ), permission_service.ListPermissionsResponse( permissions=[], - next_page_token='def', + next_page_token="def", ), permission_service.ListPermissionsResponse( permissions=[ permission.Permission(), ], - next_page_token='ghi', + next_page_token="ghi", ), permission_service.ListPermissionsResponse( permissions=[ @@ -1457,18 +1665,22 @@ async def test_list_permissions_async_pages(): pages = [] # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 - async for page_ in ( # pragma: no branch + async for page_ in ( # pragma: no branch await client.list_permissions(request={}) ).pages: pages.append(page_) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -@pytest.mark.parametrize("request_type", [ - permission_service.UpdatePermissionRequest, - dict, -]) -def test_update_permission(request_type, transport: str = 'grpc'): + +@pytest.mark.parametrize( + "request_type", + [ + permission_service.UpdatePermissionRequest, + dict, + ], +) +def test_update_permission(request_type, transport: str = "grpc"): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1480,13 +1692,13 @@ def test_update_permission(request_type, transport: str = 'grpc'): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.update_permission), - '__call__') as call: + type(client.transport.update_permission), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gag_permission.Permission( - name='name_value', + name="name_value", grantee_type=gag_permission.Permission.GranteeType.USER, - email_address='email_address_value', + email_address="email_address_value", role=gag_permission.Permission.Role.OWNER, ) response = client.update_permission(request) @@ -1498,9 +1710,9 @@ def test_update_permission(request_type, transport: str = 'grpc'): # Establish that the response is the type that we expect. assert isinstance(response, gag_permission.Permission) - assert response.name == 'name_value' + assert response.name == "name_value" assert response.grantee_type == gag_permission.Permission.GranteeType.USER - assert response.email_address == 'email_address_value' + assert response.email_address == "email_address_value" assert response.role == gag_permission.Permission.Role.OWNER @@ -1509,20 +1721,24 @@ def test_update_permission_empty_call(): # i.e. request == None and no flattened fields passed, work. client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_permission), - '__call__') as call: + type(client.transport.update_permission), "__call__" + ) as call: client.update_permission() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == permission_service.UpdatePermissionRequest() + @pytest.mark.asyncio -async def test_update_permission_async(transport: str = 'grpc_asyncio', request_type=permission_service.UpdatePermissionRequest): +async def test_update_permission_async( + transport: str = "grpc_asyncio", + request_type=permission_service.UpdatePermissionRequest, +): client = PermissionServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1534,15 +1750,17 @@ async def test_update_permission_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_permission), - '__call__') as call: + type(client.transport.update_permission), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(gag_permission.Permission( - name='name_value', - grantee_type=gag_permission.Permission.GranteeType.USER, - email_address='email_address_value', - role=gag_permission.Permission.Role.OWNER, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission( + name="name_value", + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=gag_permission.Permission.Role.OWNER, + ) + ) response = await client.update_permission(request) # Establish that the underlying gRPC stub method was called. @@ -1552,9 +1770,9 @@ async def test_update_permission_async(transport: str = 'grpc_asyncio', request_ # Establish that the response is the type that we expect. 
assert isinstance(response, gag_permission.Permission) - assert response.name == 'name_value' + assert response.name == "name_value" assert response.grantee_type == gag_permission.Permission.GranteeType.USER - assert response.email_address == 'email_address_value' + assert response.email_address == "email_address_value" assert response.role == gag_permission.Permission.Role.OWNER @@ -1572,12 +1790,12 @@ def test_update_permission_field_headers(): # a field header. Set these to a non-empty value. request = permission_service.UpdatePermissionRequest() - request.permission.name = 'name_value' + request.permission.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_permission), - '__call__') as call: + type(client.transport.update_permission), "__call__" + ) as call: call.return_value = gag_permission.Permission() client.update_permission(request) @@ -1589,9 +1807,9 @@ def test_update_permission_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'permission.name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "permission.name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1604,13 +1822,15 @@ async def test_update_permission_field_headers_async(): # a field header. Set these to a non-empty value. request = permission_service.UpdatePermissionRequest() - request.permission.name = 'name_value' + request.permission.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_permission), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_permission.Permission()) + type(client.transport.update_permission), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission() + ) await client.update_permission(request) # Establish that the underlying gRPC stub method was called. @@ -1621,9 +1841,9 @@ async def test_update_permission_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'permission.name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "permission.name=name_value", + ) in kw["metadata"] def test_update_permission_flattened(): @@ -1633,15 +1853,15 @@ def test_update_permission_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_permission), - '__call__') as call: + type(client.transport.update_permission), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gag_permission.Permission() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.update_permission( - permission=gag_permission.Permission(name='name_value'), - update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + permission=gag_permission.Permission(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1649,10 +1869,10 @@ def test_update_permission_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].permission - mock_val = gag_permission.Permission(name='name_value') + mock_val = gag_permission.Permission(name="name_value") assert arg == mock_val arg = args[0].update_mask - mock_val = field_mask_pb2.FieldMask(paths=['paths_value']) + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) assert arg == mock_val @@ -1666,10 +1886,11 @@ def test_update_permission_flattened_error(): with pytest.raises(ValueError): client.update_permission( permission_service.UpdatePermissionRequest(), - permission=gag_permission.Permission(name='name_value'), - update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + permission=gag_permission.Permission(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) + @pytest.mark.asyncio async def test_update_permission_flattened_async(): client = PermissionServiceAsyncClient( @@ -1678,17 +1899,19 @@ async def test_update_permission_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.update_permission), - '__call__') as call: + type(client.transport.update_permission), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = gag_permission.Permission() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gag_permission.Permission()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
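        # update_mask is the well-known protobuf FieldMask; "paths_value" is just
        # a placeholder, whereas a real call would name the Permission fields to
        # rewrite, e.g. update_mask=field_mask_pb2.FieldMask(paths=["role"]).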
response = await client.update_permission( - permission=gag_permission.Permission(name='name_value'), - update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + permission=gag_permission.Permission(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) # Establish that the underlying call was made with the expected @@ -1696,12 +1919,13 @@ async def test_update_permission_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].permission - mock_val = gag_permission.Permission(name='name_value') + mock_val = gag_permission.Permission(name="name_value") assert arg == mock_val arg = args[0].update_mask - mock_val = field_mask_pb2.FieldMask(paths=['paths_value']) + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) assert arg == mock_val + @pytest.mark.asyncio async def test_update_permission_flattened_error_async(): client = PermissionServiceAsyncClient( @@ -1713,16 +1937,19 @@ async def test_update_permission_flattened_error_async(): with pytest.raises(ValueError): await client.update_permission( permission_service.UpdatePermissionRequest(), - permission=gag_permission.Permission(name='name_value'), - update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + permission=gag_permission.Permission(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) -@pytest.mark.parametrize("request_type", [ - permission_service.DeletePermissionRequest, - dict, -]) -def test_delete_permission(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + permission_service.DeletePermissionRequest, + dict, + ], +) +def test_delete_permission(request_type, transport: str = "grpc"): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1734,8 +1961,8 @@ def test_delete_permission(request_type, transport: str = 'grpc'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_permission), - '__call__') as call: + type(client.transport.delete_permission), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None response = client.delete_permission(request) @@ -1754,20 +1981,24 @@ def test_delete_permission_empty_call(): # i.e. request == None and no flattened fields passed, work. client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_permission), - '__call__') as call: + type(client.transport.delete_permission), "__call__" + ) as call: client.delete_permission() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == permission_service.DeletePermissionRequest() + @pytest.mark.asyncio -async def test_delete_permission_async(transport: str = 'grpc_asyncio', request_type=permission_service.DeletePermissionRequest): +async def test_delete_permission_async( + transport: str = "grpc_asyncio", + request_type=permission_service.DeletePermissionRequest, +): client = PermissionServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1779,8 +2010,8 @@ async def test_delete_permission_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.delete_permission), - '__call__') as call: + type(client.transport.delete_permission), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) response = await client.delete_permission(request) @@ -1808,12 +2039,12 @@ def test_delete_permission_field_headers(): # a field header. Set these to a non-empty value. request = permission_service.DeletePermissionRequest() - request.name = 'name_value' + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_permission), - '__call__') as call: + type(client.transport.delete_permission), "__call__" + ) as call: call.return_value = None client.delete_permission(request) @@ -1825,9 +2056,9 @@ def test_delete_permission_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1840,12 +2071,12 @@ async def test_delete_permission_field_headers_async(): # a field header. Set these to a non-empty value. request = permission_service.DeletePermissionRequest() - request.name = 'name_value' + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_permission), - '__call__') as call: + type(client.transport.delete_permission), "__call__" + ) as call: call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) await client.delete_permission(request) @@ -1857,9 +2088,9 @@ async def test_delete_permission_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] def test_delete_permission_flattened(): @@ -1869,14 +2100,14 @@ def test_delete_permission_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.delete_permission), - '__call__') as call: + type(client.transport.delete_permission), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.delete_permission( - name='name_value', + name="name_value", ) # Establish that the underlying call was made with the expected @@ -1884,7 +2115,7 @@ def test_delete_permission_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].name - mock_val = 'name_value' + mock_val = "name_value" assert arg == mock_val @@ -1898,9 +2129,10 @@ def test_delete_permission_flattened_error(): with pytest.raises(ValueError): client.delete_permission( permission_service.DeletePermissionRequest(), - name='name_value', + name="name_value", ) + @pytest.mark.asyncio async def test_delete_permission_flattened_async(): client = PermissionServiceAsyncClient( @@ -1909,8 +2141,8 @@ async def test_delete_permission_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. 
with mock.patch.object( - type(client.transport.delete_permission), - '__call__') as call: + type(client.transport.delete_permission), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = None @@ -1918,7 +2150,7 @@ async def test_delete_permission_flattened_async(): # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.delete_permission( - name='name_value', + name="name_value", ) # Establish that the underlying call was made with the expected @@ -1926,9 +2158,10 @@ async def test_delete_permission_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].name - mock_val = 'name_value' + mock_val = "name_value" assert arg == mock_val + @pytest.mark.asyncio async def test_delete_permission_flattened_error_async(): client = PermissionServiceAsyncClient( @@ -1940,15 +2173,18 @@ async def test_delete_permission_flattened_error_async(): with pytest.raises(ValueError): await client.delete_permission( permission_service.DeletePermissionRequest(), - name='name_value', + name="name_value", ) -@pytest.mark.parametrize("request_type", [ - permission_service.TransferOwnershipRequest, - dict, -]) -def test_transfer_ownership(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + permission_service.TransferOwnershipRequest, + dict, + ], +) +def test_transfer_ownership(request_type, transport: str = "grpc"): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1960,11 +2196,10 @@ def test_transfer_ownership(request_type, transport: str = 'grpc'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.transfer_ownership), - '__call__') as call: + type(client.transport.transfer_ownership), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value = permission_service.TransferOwnershipResponse( - ) + call.return_value = permission_service.TransferOwnershipResponse() response = client.transfer_ownership(request) # Establish that the underlying gRPC stub method was called. @@ -1981,20 +2216,24 @@ def test_transfer_ownership_empty_call(): # i.e. request == None and no flattened fields passed, work. client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.transfer_ownership), - '__call__') as call: + type(client.transport.transfer_ownership), "__call__" + ) as call: client.transfer_ownership() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == permission_service.TransferOwnershipRequest() + @pytest.mark.asyncio -async def test_transfer_ownership_async(transport: str = 'grpc_asyncio', request_type=permission_service.TransferOwnershipRequest): +async def test_transfer_ownership_async( + transport: str = "grpc_asyncio", + request_type=permission_service.TransferOwnershipRequest, +): client = PermissionServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2006,11 +2245,12 @@ async def test_transfer_ownership_async(transport: str = 'grpc_asyncio', request # Mock the actual call within the gRPC stub, and fake the request. 
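    # The canned TransferOwnershipResponse below is constructed with no fields;
    # the test only verifies the request round trip and the response type.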
with mock.patch.object( - type(client.transport.transfer_ownership), - '__call__') as call: + type(client.transport.transfer_ownership), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(permission_service.TransferOwnershipResponse( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission_service.TransferOwnershipResponse() + ) response = await client.transfer_ownership(request) # Establish that the underlying gRPC stub method was called. @@ -2036,12 +2276,12 @@ def test_transfer_ownership_field_headers(): # a field header. Set these to a non-empty value. request = permission_service.TransferOwnershipRequest() - request.name = 'name_value' + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.transfer_ownership), - '__call__') as call: + type(client.transport.transfer_ownership), "__call__" + ) as call: call.return_value = permission_service.TransferOwnershipResponse() client.transfer_ownership(request) @@ -2053,9 +2293,9 @@ def test_transfer_ownership_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -2068,13 +2308,15 @@ async def test_transfer_ownership_field_headers_async(): # a field header. Set these to a non-empty value. request = permission_service.TransferOwnershipRequest() - request.name = 'name_value' + request.name = "name_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.transfer_ownership), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(permission_service.TransferOwnershipResponse()) + type(client.transport.transfer_ownership), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission_service.TransferOwnershipResponse() + ) await client.transfer_ownership(request) # Establish that the underlying gRPC stub method was called. @@ -2085,15 +2327,18 @@ async def test_transfer_ownership_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'name=name_value', - ) in kw['metadata'] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] -@pytest.mark.parametrize("request_type", [ - permission_service.CreatePermissionRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + permission_service.CreatePermissionRequest, + dict, + ], +) def test_create_permission_rest(request_type): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -2101,18 +2346,23 @@ def test_create_permission_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'parent': 'tunedModels/sample1'} - request_init["permission"] = {'name': 'name_value', 'grantee_type': 1, 'email_address': 'email_address_value', 'role': 1} + request_init = {"parent": "tunedModels/sample1"} + request_init["permission"] = { + "name": "name_value", + "grantee_type": 1, + "email_address": "email_address_value", + "role": 1, + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. 
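    # The REST tests never touch the network: requests.Session.request on the
    # transport is mocked, and the canned proto below is serialized with
    # json_format.MessageToJson, roughly as a real server would encode the body.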
- with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = gag_permission.Permission( - name='name_value', - grantee_type=gag_permission.Permission.GranteeType.USER, - email_address='email_address_value', - role=gag_permission.Permission.Role.OWNER, + name="name_value", + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=gag_permission.Permission.Role.OWNER, ) # Wrap the value into a proper Response obj @@ -2121,70 +2371,78 @@ def test_create_permission_rest(request_type): pb_return_value = gag_permission.Permission.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.create_permission(request) # Establish that the response is the type that we expect. assert isinstance(response, gag_permission.Permission) - assert response.name == 'name_value' + assert response.name == "name_value" assert response.grantee_type == gag_permission.Permission.GranteeType.USER - assert response.email_address == 'email_address_value' + assert response.email_address == "email_address_value" assert response.role == gag_permission.Permission.Role.OWNER -def test_create_permission_rest_required_fields(request_type=permission_service.CreatePermissionRequest): +def test_create_permission_rest_required_fields( + request_type=permission_service.CreatePermissionRequest, +): transport_class = transports.PermissionServiceRestTransport request_init = {} request_init["parent"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).create_permission._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_permission._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = 'parent_value' + jsonified_request["parent"] = "parent_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).create_permission._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_permission._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "parent" in jsonified_request - assert jsonified_request["parent"] == 'parent_value' + assert jsonified_request["parent"] == "parent_value" client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. 
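    # At this point jsonified_request has been passed through
    # _get_unset_required_fields twice: once to confirm that default-valued
    # fields are dropped, and once to confirm an explicitly set "parent"
    # survives untouched.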
return_value = gag_permission.Permission() # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "post", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, } - transcode_result['body'] = pb_request + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() @@ -2193,39 +2451,56 @@ def test_create_permission_rest_required_fields(request_type=permission_service. pb_return_value = gag_permission.Permission.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.create_permission(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_create_permission_rest_unset_required_fields(): - transport = transports.PermissionServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.create_permission._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("parent", "permission", ))) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "permission", + ) + ) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_create_permission_rest_interceptors(null_interceptor): transport = transports.PermissionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.PermissionServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.PermissionServiceRestInterceptor(), + ) client = PermissionServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.PermissionServiceRestInterceptor, "post_create_permission") as post, \ - mock.patch.object(transports.PermissionServiceRestInterceptor, "pre_create_permission") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.PermissionServiceRestInterceptor, "post_create_permission" + ) as post, mock.patch.object( + transports.PermissionServiceRestInterceptor, "pre_create_permission" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = permission_service.CreatePermissionRequest.pb(permission_service.CreatePermissionRequest()) + pb_message = 
permission_service.CreatePermissionRequest.pb( + permission_service.CreatePermissionRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -2236,35 +2511,52 @@ def test_create_permission_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = gag_permission.Permission.to_json(gag_permission.Permission()) + req.return_value._content = gag_permission.Permission.to_json( + gag_permission.Permission() + ) request = permission_service.CreatePermissionRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = gag_permission.Permission() - client.create_permission(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.create_permission( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_create_permission_rest_bad_request(transport: str = 'rest', request_type=permission_service.CreatePermissionRequest): +def test_create_permission_rest_bad_request( + transport: str = "rest", request_type=permission_service.CreatePermissionRequest +): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'parent': 'tunedModels/sample1'} - request_init["permission"] = {'name': 'name_value', 'grantee_type': 1, 'email_address': 'email_address_value', 'role': 1} + request_init = {"parent": "tunedModels/sample1"} + request_init["permission"] = { + "name": "name_value", + "grantee_type": 1, + "email_address": "email_address_value", + "role": 1, + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -2280,17 +2572,17 @@ def test_create_permission_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = gag_permission.Permission() # get arguments that satisfy an http rule for this method - sample_request = {'parent': 'tunedModels/sample1'} + sample_request = {"parent": "tunedModels/sample1"} # get truthy value for each flattened field mock_args = dict( - parent='parent_value', - permission=gag_permission.Permission(name='name_value'), + parent="parent_value", + permission=gag_permission.Permission(name="name_value"), ) mock_args.update(sample_request) @@ -2299,7 +2591,7 @@ def test_create_permission_rest_flattened(): response_value.status_code = 200 pb_return_value = gag_permission.Permission.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.create_permission(**mock_args) @@ -2308,10 +2600,13 @@ def test_create_permission_rest_flattened(): # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{parent=tunedModels/*}/permissions" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{parent=tunedModels/*}/permissions" % client.transport._host, + args[1], + ) -def test_create_permission_rest_flattened_error(transport: str = 'rest'): +def test_create_permission_rest_flattened_error(transport: str = "rest"): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2322,22 +2617,24 @@ def test_create_permission_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.create_permission( permission_service.CreatePermissionRequest(), - parent='parent_value', - permission=gag_permission.Permission(name='name_value'), + parent="parent_value", + permission=gag_permission.Permission(name="name_value"), ) def test_create_permission_rest_error(): client = PermissionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) -@pytest.mark.parametrize("request_type", [ - permission_service.GetPermissionRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + permission_service.GetPermissionRequest, + dict, + ], +) def test_get_permission_rest(request_type): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -2345,17 +2642,17 @@ def test_get_permission_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'name': 'tunedModels/sample1/permissions/sample2'} + request_init = {"name": "tunedModels/sample1/permissions/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = permission.Permission( - name='name_value', - grantee_type=permission.Permission.GranteeType.USER, - email_address='email_address_value', - role=permission.Permission.Role.OWNER, + name="name_value", + grantee_type=permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=permission.Permission.Role.OWNER, ) # Wrap the value into a proper Response obj @@ -2364,68 +2661,76 @@ def test_get_permission_rest(request_type): pb_return_value = permission.Permission.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.get_permission(request) # Establish that the response is the type that we expect. 
assert isinstance(response, permission.Permission) - assert response.name == 'name_value' + assert response.name == "name_value" assert response.grantee_type == permission.Permission.GranteeType.USER - assert response.email_address == 'email_address_value' + assert response.email_address == "email_address_value" assert response.role == permission.Permission.Role.OWNER -def test_get_permission_rest_required_fields(request_type=permission_service.GetPermissionRequest): +def test_get_permission_rest_required_fields( + request_type=permission_service.GetPermissionRequest, +): transport_class = transports.PermissionServiceRestTransport request_init = {} request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_permission._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_permission._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["name"] = 'name_value' + jsonified_request["name"] = "name_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).get_permission._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_permission._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "name" in jsonified_request - assert jsonified_request["name"] == 'name_value' + assert jsonified_request["name"] == "name_value" client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = permission.Permission() # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. 
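# Why the required-fields test serializes with defaults dropped (standalone
# sketch, same package assumption as above): unset proto fields disappear from
# the JSON, which is what lets `_get_unset_required_fields` spot a missing
# required field such as `name`.
from google.ai.generativelanguage_v1beta3.types import permission_service
from google.protobuf import json_format

req_msg = permission_service.GetPermissionRequest()  # required "name" left unset
pb = permission_service.GetPermissionRequest.pb(req_msg)
assert json_format.MessageToJson(pb, including_default_value_fields=False) == "{}"

req_msg.name = "tunedModels/sample1/permissions/sample2"
pb = permission_service.GetPermissionRequest.pb(req_msg)
assert "name" in json_format.MessageToJson(pb)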
pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "get", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, } transcode.return_value = transcode_result @@ -2435,39 +2740,48 @@ def test_get_permission_rest_required_fields(request_type=permission_service.Get pb_return_value = permission.Permission.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.get_permission(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_get_permission_rest_unset_required_fields(): - transport = transports.PermissionServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.get_permission._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("name", ))) + assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_get_permission_rest_interceptors(null_interceptor): transport = transports.PermissionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.PermissionServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.PermissionServiceRestInterceptor(), + ) client = PermissionServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.PermissionServiceRestInterceptor, "post_get_permission") as post, \ - mock.patch.object(transports.PermissionServiceRestInterceptor, "pre_get_permission") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.PermissionServiceRestInterceptor, "post_get_permission" + ) as post, mock.patch.object( + transports.PermissionServiceRestInterceptor, "pre_get_permission" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = permission_service.GetPermissionRequest.pb(permission_service.GetPermissionRequest()) + pb_message = permission_service.GetPermissionRequest.pb( + permission_service.GetPermissionRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -2478,34 +2792,46 @@ def test_get_permission_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = permission.Permission.to_json(permission.Permission()) + req.return_value._content = permission.Permission.to_json( + permission.Permission() + ) request = permission_service.GetPermissionRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = permission.Permission() - client.get_permission(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + 
client.get_permission( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_get_permission_rest_bad_request(transport: str = 'rest', request_type=permission_service.GetPermissionRequest): +def test_get_permission_rest_bad_request( + transport: str = "rest", request_type=permission_service.GetPermissionRequest +): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'name': 'tunedModels/sample1/permissions/sample2'} + request_init = {"name": "tunedModels/sample1/permissions/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -2521,16 +2847,16 @@ def test_get_permission_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = permission.Permission() # get arguments that satisfy an http rule for this method - sample_request = {'name': 'tunedModels/sample1/permissions/sample2'} + sample_request = {"name": "tunedModels/sample1/permissions/sample2"} # get truthy value for each flattened field mock_args = dict( - name='name_value', + name="name_value", ) mock_args.update(sample_request) @@ -2539,7 +2865,7 @@ def test_get_permission_rest_flattened(): response_value.status_code = 200 pb_return_value = permission.Permission.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.get_permission(**mock_args) @@ -2548,10 +2874,13 @@ def test_get_permission_rest_flattened(): # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{name=tunedModels/*/permissions/*}" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{name=tunedModels/*/permissions/*}" % client.transport._host, + args[1], + ) -def test_get_permission_rest_flattened_error(transport: str = 'rest'): +def test_get_permission_rest_flattened_error(transport: str = "rest"): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2562,21 +2891,23 @@ def test_get_permission_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.get_permission( permission_service.GetPermissionRequest(), - name='name_value', + name="name_value", ) def test_get_permission_rest_error(): client = PermissionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) -@pytest.mark.parametrize("request_type", [ - permission_service.ListPermissionsRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + permission_service.ListPermissionsRequest, + dict, + ], +) def test_list_permissions_rest(request_type): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -2584,14 +2915,14 @@ def test_list_permissions_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'parent': 'tunedModels/sample1'} + request_init = {"parent": "tunedModels/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = permission_service.ListPermissionsResponse( - next_page_token='next_page_token_value', + next_page_token="next_page_token_value", ) # Wrap the value into a proper Response obj @@ -2600,109 +2931,141 @@ def test_list_permissions_rest(request_type): pb_return_value = permission_service.ListPermissionsResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_permissions(request) # Establish that the response is the type that we expect. 
assert isinstance(response, pagers.ListPermissionsPager) - assert response.next_page_token == 'next_page_token_value' + assert response.next_page_token == "next_page_token_value" -def test_list_permissions_rest_required_fields(request_type=permission_service.ListPermissionsRequest): +def test_list_permissions_rest_required_fields( + request_type=permission_service.ListPermissionsRequest, +): transport_class = transports.PermissionServiceRestTransport request_init = {} request_init["parent"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).list_permissions._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_permissions._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["parent"] = 'parent_value' + jsonified_request["parent"] = "parent_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).list_permissions._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_permissions._get_unset_required_fields(jsonified_request) # Check that path parameters and body parameters are not mixing in. - assert not set(unset_fields) - set(("page_size", "page_token", )) + assert not set(unset_fields) - set( + ( + "page_size", + "page_token", + ) + ) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "parent" in jsonified_request - assert jsonified_request["parent"] == 'parent_value' + assert jsonified_request["parent"] == "parent_value" client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = permission_service.ListPermissionsResponse() # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. 
pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "get", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, } transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 - pb_return_value = permission_service.ListPermissionsResponse.pb(return_value) + pb_return_value = permission_service.ListPermissionsResponse.pb( + return_value + ) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_permissions(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_list_permissions_rest_unset_required_fields(): - transport = transports.PermissionServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.list_permissions._get_unset_required_fields({}) - assert set(unset_fields) == (set(("pageSize", "pageToken", )) & set(("parent", ))) + assert set(unset_fields) == ( + set( + ( + "pageSize", + "pageToken", + ) + ) + & set(("parent",)) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_list_permissions_rest_interceptors(null_interceptor): transport = transports.PermissionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.PermissionServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.PermissionServiceRestInterceptor(), + ) client = PermissionServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.PermissionServiceRestInterceptor, "post_list_permissions") as post, \ - mock.patch.object(transports.PermissionServiceRestInterceptor, "pre_list_permissions") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.PermissionServiceRestInterceptor, "post_list_permissions" + ) as post, mock.patch.object( + transports.PermissionServiceRestInterceptor, "pre_list_permissions" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = permission_service.ListPermissionsRequest.pb(permission_service.ListPermissionsRequest()) + pb_message = permission_service.ListPermissionsRequest.pb( + permission_service.ListPermissionsRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -2713,34 +3076,46 @@ def test_list_permissions_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = permission_service.ListPermissionsResponse.to_json(permission_service.ListPermissionsResponse()) + req.return_value._content = permission_service.ListPermissionsResponse.to_json( + permission_service.ListPermissionsResponse() + ) request = permission_service.ListPermissionsRequest() - 
metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = permission_service.ListPermissionsResponse() - client.list_permissions(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.list_permissions( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_list_permissions_rest_bad_request(transport: str = 'rest', request_type=permission_service.ListPermissionsRequest): +def test_list_permissions_rest_bad_request( + transport: str = "rest", request_type=permission_service.ListPermissionsRequest +): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'parent': 'tunedModels/sample1'} + request_init = {"parent": "tunedModels/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -2756,16 +3131,16 @@ def test_list_permissions_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = permission_service.ListPermissionsResponse() # get arguments that satisfy an http rule for this method - sample_request = {'parent': 'tunedModels/sample1'} + sample_request = {"parent": "tunedModels/sample1"} # get truthy value for each flattened field mock_args = dict( - parent='parent_value', + parent="parent_value", ) mock_args.update(sample_request) @@ -2774,7 +3149,7 @@ def test_list_permissions_rest_flattened(): response_value.status_code = 200 pb_return_value = permission_service.ListPermissionsResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.list_permissions(**mock_args) @@ -2783,10 +3158,13 @@ def test_list_permissions_rest_flattened(): # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{parent=tunedModels/*}/permissions" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{parent=tunedModels/*}/permissions" % client.transport._host, + args[1], + ) -def test_list_permissions_rest_flattened_error(transport: str = 'rest'): +def test_list_permissions_rest_flattened_error(transport: str = "rest"): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2797,20 +3175,20 @@ def test_list_permissions_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.list_permissions( permission_service.ListPermissionsRequest(), - parent='parent_value', + parent="parent_value", ) -def test_list_permissions_rest_pager(transport: str = 'rest'): +def test_list_permissions_rest_pager(transport: str = "rest"): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # TODO(kbandes): remove this mock unless there's a good reason for it. - #with mock.patch.object(path_template, 'transcode') as transcode: + # with mock.patch.object(path_template, 'transcode') as transcode: # Set the response as a series of pages response = ( permission_service.ListPermissionsResponse( @@ -2819,17 +3197,17 @@ def test_list_permissions_rest_pager(transport: str = 'rest'): permission.Permission(), permission.Permission(), ], - next_page_token='abc', + next_page_token="abc", ), permission_service.ListPermissionsResponse( permissions=[], - next_page_token='def', + next_page_token="def", ), permission_service.ListPermissionsResponse( permissions=[ permission.Permission(), ], - next_page_token='ghi', + next_page_token="ghi", ), permission_service.ListPermissionsResponse( permissions=[ @@ -2842,31 +3220,35 @@ def test_list_permissions_rest_pager(transport: str = 'rest'): response = response + response # Wrap the values into proper Response objs - response = tuple(permission_service.ListPermissionsResponse.to_json(x) for x in response) + response = tuple( + permission_service.ListPermissionsResponse.to_json(x) for x in response + ) return_values = tuple(Response() for i in response) for return_val, response_val in zip(return_values, response): - return_val._content = response_val.encode('UTF-8') + return_val._content = response_val.encode("UTF-8") return_val.status_code = 200 req.side_effect = return_values - sample_request = {'parent': 'tunedModels/sample1'} + sample_request = {"parent": "tunedModels/sample1"} pager = client.list_permissions(request=sample_request) results = list(pager) assert len(results) == 6 - assert all(isinstance(i, permission.Permission) - for i in results) + assert all(isinstance(i, permission.Permission) for i in results) pages = list(client.list_permissions(request=sample_request).pages) - for page_, token in zip(pages, ['abc','def','ghi', '']): + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -@pytest.mark.parametrize("request_type", [ - permission_service.UpdatePermissionRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + permission_service.UpdatePermissionRequest, + dict, + ], +) def test_update_permission_rest(request_type): client = PermissionServiceClient( 
credentials=ga_credentials.AnonymousCredentials(), @@ -2874,18 +3256,23 @@ def test_update_permission_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'permission': {'name': 'tunedModels/sample1/permissions/sample2'}} - request_init["permission"] = {'name': 'tunedModels/sample1/permissions/sample2', 'grantee_type': 1, 'email_address': 'email_address_value', 'role': 1} + request_init = {"permission": {"name": "tunedModels/sample1/permissions/sample2"}} + request_init["permission"] = { + "name": "tunedModels/sample1/permissions/sample2", + "grantee_type": 1, + "email_address": "email_address_value", + "role": 1, + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = gag_permission.Permission( - name='name_value', - grantee_type=gag_permission.Permission.GranteeType.USER, - email_address='email_address_value', - role=gag_permission.Permission.Role.OWNER, + name="name_value", + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=gag_permission.Permission.Role.OWNER, ) # Wrap the value into a proper Response obj @@ -2894,67 +3281,75 @@ def test_update_permission_rest(request_type): pb_return_value = gag_permission.Permission.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.update_permission(request) # Establish that the response is the type that we expect. 
assert isinstance(response, gag_permission.Permission) - assert response.name == 'name_value' + assert response.name == "name_value" assert response.grantee_type == gag_permission.Permission.GranteeType.USER - assert response.email_address == 'email_address_value' + assert response.email_address == "email_address_value" assert response.role == gag_permission.Permission.Role.OWNER -def test_update_permission_rest_required_fields(request_type=permission_service.UpdatePermissionRequest): +def test_update_permission_rest_required_fields( + request_type=permission_service.UpdatePermissionRequest, +): transport_class = transports.PermissionServiceRestTransport request_init = {} request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).update_permission._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_permission._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).update_permission._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_permission._get_unset_required_fields(jsonified_request) # Check that path parameters and body parameters are not mixing in. - assert not set(unset_fields) - set(("update_mask", )) + assert not set(unset_fields) - set(("update_mask",)) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = gag_permission.Permission() # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "patch", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, } - transcode_result['body'] = pb_request + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() @@ -2963,39 +3358,56 @@ def test_update_permission_rest_required_fields(request_type=permission_service. 
pb_return_value = gag_permission.Permission.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.update_permission(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_update_permission_rest_unset_required_fields(): - transport = transports.PermissionServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.update_permission._get_unset_required_fields({}) - assert set(unset_fields) == (set(("updateMask", )) & set(("permission", "updateMask", ))) + assert set(unset_fields) == ( + set(("updateMask",)) + & set( + ( + "permission", + "updateMask", + ) + ) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_update_permission_rest_interceptors(null_interceptor): transport = transports.PermissionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.PermissionServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.PermissionServiceRestInterceptor(), + ) client = PermissionServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.PermissionServiceRestInterceptor, "post_update_permission") as post, \ - mock.patch.object(transports.PermissionServiceRestInterceptor, "pre_update_permission") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.PermissionServiceRestInterceptor, "post_update_permission" + ) as post, mock.patch.object( + transports.PermissionServiceRestInterceptor, "pre_update_permission" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = permission_service.UpdatePermissionRequest.pb(permission_service.UpdatePermissionRequest()) + pb_message = permission_service.UpdatePermissionRequest.pb( + permission_service.UpdatePermissionRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -3006,35 +3418,52 @@ def test_update_permission_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = gag_permission.Permission.to_json(gag_permission.Permission()) + req.return_value._content = gag_permission.Permission.to_json( + gag_permission.Permission() + ) request = permission_service.UpdatePermissionRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = gag_permission.Permission() - client.update_permission(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.update_permission( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_update_permission_rest_bad_request(transport: str = 
'rest', request_type=permission_service.UpdatePermissionRequest): +def test_update_permission_rest_bad_request( + transport: str = "rest", request_type=permission_service.UpdatePermissionRequest +): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'permission': {'name': 'tunedModels/sample1/permissions/sample2'}} - request_init["permission"] = {'name': 'tunedModels/sample1/permissions/sample2', 'grantee_type': 1, 'email_address': 'email_address_value', 'role': 1} + request_init = {"permission": {"name": "tunedModels/sample1/permissions/sample2"}} + request_init["permission"] = { + "name": "tunedModels/sample1/permissions/sample2", + "grantee_type": 1, + "email_address": "email_address_value", + "role": 1, + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -3050,17 +3479,19 @@ def test_update_permission_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = gag_permission.Permission() # get arguments that satisfy an http rule for this method - sample_request = {'permission': {'name': 'tunedModels/sample1/permissions/sample2'}} + sample_request = { + "permission": {"name": "tunedModels/sample1/permissions/sample2"} + } # get truthy value for each flattened field mock_args = dict( - permission=gag_permission.Permission(name='name_value'), - update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + permission=gag_permission.Permission(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) mock_args.update(sample_request) @@ -3069,7 +3500,7 @@ def test_update_permission_rest_flattened(): response_value.status_code = 200 pb_return_value = gag_permission.Permission.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.update_permission(**mock_args) @@ -3078,10 +3509,14 @@ def test_update_permission_rest_flattened(): # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{permission.name=tunedModels/*/permissions/*}" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{permission.name=tunedModels/*/permissions/*}" + % client.transport._host, + args[1], + ) -def test_update_permission_rest_flattened_error(transport: str = 'rest'): +def test_update_permission_rest_flattened_error(transport: str = "rest"): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3092,22 +3527,24 @@ def test_update_permission_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.update_permission( permission_service.UpdatePermissionRequest(), - permission=gag_permission.Permission(name='name_value'), - update_mask=field_mask_pb2.FieldMask(paths=['paths_value']), + permission=gag_permission.Permission(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), ) def test_update_permission_rest_error(): client = PermissionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) -@pytest.mark.parametrize("request_type", [ - permission_service.DeletePermissionRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + permission_service.DeletePermissionRequest, + dict, + ], +) def test_delete_permission_rest(request_type): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -3115,20 +3552,20 @@ def test_delete_permission_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'name': 'tunedModels/sample1/permissions/sample2'} + request_init = {"name": "tunedModels/sample1/permissions/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
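# DeletePermission transcodes to HTTP DELETE with an empty body (see the
# "delete" transcode stub further below), so the faked response carries no
# content and the client surfaces `None`. Standalone sketch of that
# empty-body Response:
import requests

empty = requests.Response()
empty.status_code = 200
empty._content = "".encode("UTF-8")  # a successful delete returns no payload
assert empty.content == b""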
return_value = None # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - json_return_value = '' + json_return_value = "" - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.delete_permission(request) @@ -3136,94 +3573,110 @@ def test_delete_permission_rest(request_type): assert response is None -def test_delete_permission_rest_required_fields(request_type=permission_service.DeletePermissionRequest): +def test_delete_permission_rest_required_fields( + request_type=permission_service.DeletePermissionRequest, +): transport_class = transports.PermissionServiceRestTransport request_init = {} request_init["name"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).delete_permission._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_permission._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["name"] = 'name_value' + jsonified_request["name"] = "name_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).delete_permission._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_permission._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "name" in jsonified_request - assert jsonified_request["name"] == 'name_value' + assert jsonified_request["name"] == "name_value" client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = None # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. 
pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "delete", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, } transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 - json_return_value = '' + json_return_value = "" - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.delete_permission(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_delete_permission_rest_unset_required_fields(): - transport = transports.PermissionServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.delete_permission._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("name", ))) + assert set(unset_fields) == (set(()) & set(("name",))) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_delete_permission_rest_interceptors(null_interceptor): transport = transports.PermissionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.PermissionServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.PermissionServiceRestInterceptor(), + ) client = PermissionServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.PermissionServiceRestInterceptor, "pre_delete_permission") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.PermissionServiceRestInterceptor, "pre_delete_permission" + ) as pre: pre.assert_not_called() - pb_message = permission_service.DeletePermissionRequest.pb(permission_service.DeletePermissionRequest()) + pb_message = permission_service.DeletePermissionRequest.pb( + permission_service.DeletePermissionRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -3236,29 +3689,39 @@ def test_delete_permission_rest_interceptors(null_interceptor): req.return_value.request = PreparedRequest() request = permission_service.DeletePermissionRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata - client.delete_permission(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.delete_permission( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() -def test_delete_permission_rest_bad_request(transport: str = 'rest', request_type=permission_service.DeletePermissionRequest): +def test_delete_permission_rest_bad_request( + transport: str = "rest", request_type=permission_service.DeletePermissionRequest +): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'name': 
'tunedModels/sample1/permissions/sample2'} + request_init = {"name": "tunedModels/sample1/permissions/sample2"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -3274,24 +3737,24 @@ def test_delete_permission_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = None # get arguments that satisfy an http rule for this method - sample_request = {'name': 'tunedModels/sample1/permissions/sample2'} + sample_request = {"name": "tunedModels/sample1/permissions/sample2"} # get truthy value for each flattened field mock_args = dict( - name='name_value', + name="name_value", ) mock_args.update(sample_request) # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 200 - json_return_value = '' - response_value._content = json_return_value.encode('UTF-8') + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.delete_permission(**mock_args) @@ -3300,10 +3763,13 @@ def test_delete_permission_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{name=tunedModels/*/permissions/*}" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{name=tunedModels/*/permissions/*}" % client.transport._host, + args[1], + ) -def test_delete_permission_rest_flattened_error(transport: str = 'rest'): +def test_delete_permission_rest_flattened_error(transport: str = "rest"): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -3314,21 +3780,23 @@ def test_delete_permission_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.delete_permission( permission_service.DeletePermissionRequest(), - name='name_value', + name="name_value", ) def test_delete_permission_rest_error(): client = PermissionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) -@pytest.mark.parametrize("request_type", [ - permission_service.TransferOwnershipRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + permission_service.TransferOwnershipRequest, + dict, + ], +) def test_transfer_ownership_rest(request_type): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -3336,14 +3804,13 @@ def test_transfer_ownership_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'name': 'tunedModels/sample1'} + request_init = {"name": "tunedModels/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. 
- with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = permission_service.TransferOwnershipResponse( - ) + return_value = permission_service.TransferOwnershipResponse() # Wrap the value into a proper Response obj response_value = Response() @@ -3351,7 +3818,7 @@ def test_transfer_ownership_rest(request_type): pb_return_value = permission_service.TransferOwnershipResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.transfer_ownership(request) @@ -3359,7 +3826,9 @@ def test_transfer_ownership_rest(request_type): assert isinstance(response, permission_service.TransferOwnershipResponse) -def test_transfer_ownership_rest_required_fields(request_type=permission_service.TransferOwnershipRequest): +def test_transfer_ownership_rest_required_fields( + request_type=permission_service.TransferOwnershipRequest, +): transport_class = transports.PermissionServiceRestTransport request_init = {} @@ -3367,95 +3836,120 @@ def test_transfer_ownership_rest_required_fields(request_type=permission_service request_init["email_address"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).transfer_ownership._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).transfer_ownership._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["name"] = 'name_value' - jsonified_request["emailAddress"] = 'email_address_value' + jsonified_request["name"] = "name_value" + jsonified_request["emailAddress"] = "email_address_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).transfer_ownership._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).transfer_ownership._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "name" in jsonified_request - assert jsonified_request["name"] == 'name_value' + assert jsonified_request["name"] == "name_value" assert "emailAddress" in jsonified_request - assert jsonified_request["emailAddress"] == 'email_address_value' + assert jsonified_request["emailAddress"] == "email_address_value" client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = permission_service.TransferOwnershipResponse() # Mock the http request call within the method and fake a response. 
- with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "post", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, } - transcode_result['body'] = pb_request + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() response_value.status_code = 200 - pb_return_value = permission_service.TransferOwnershipResponse.pb(return_value) + pb_return_value = permission_service.TransferOwnershipResponse.pb( + return_value + ) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.transfer_ownership(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_transfer_ownership_rest_unset_required_fields(): - transport = transports.PermissionServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.transfer_ownership._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("name", "emailAddress", ))) + assert set(unset_fields) == ( + set(()) + & set( + ( + "name", + "emailAddress", + ) + ) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_transfer_ownership_rest_interceptors(null_interceptor): transport = transports.PermissionServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.PermissionServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.PermissionServiceRestInterceptor(), + ) client = PermissionServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.PermissionServiceRestInterceptor, "post_transfer_ownership") as post, \ - mock.patch.object(transports.PermissionServiceRestInterceptor, "pre_transfer_ownership") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.PermissionServiceRestInterceptor, "post_transfer_ownership" + ) as post, mock.patch.object( + transports.PermissionServiceRestInterceptor, "pre_transfer_ownership" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = permission_service.TransferOwnershipRequest.pb(permission_service.TransferOwnershipRequest()) + pb_message = permission_service.TransferOwnershipRequest.pb( + 
permission_service.TransferOwnershipRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -3466,34 +3960,48 @@ def test_transfer_ownership_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = permission_service.TransferOwnershipResponse.to_json(permission_service.TransferOwnershipResponse()) + req.return_value._content = ( + permission_service.TransferOwnershipResponse.to_json( + permission_service.TransferOwnershipResponse() + ) + ) request = permission_service.TransferOwnershipRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = permission_service.TransferOwnershipResponse() - client.transfer_ownership(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.transfer_ownership( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_transfer_ownership_rest_bad_request(transport: str = 'rest', request_type=permission_service.TransferOwnershipRequest): +def test_transfer_ownership_rest_bad_request( + transport: str = "rest", request_type=permission_service.TransferOwnershipRequest +): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'name': 'tunedModels/sample1'} + request_init = {"name": "tunedModels/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -3504,8 +4012,7 @@ def test_transfer_ownership_rest_bad_request(transport: str = 'rest', request_ty def test_transfer_ownership_rest_error(): client = PermissionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -3547,8 +4054,7 @@ def test_credentials_transport_error(): options.api_key = "api_key" with pytest.raises(ValueError): client = PermissionServiceClient( - client_options=options, - credentials=ga_credentials.AnonymousCredentials() + client_options=options, credentials=ga_credentials.AnonymousCredentials() ) # It is an error to provide scopes and a transport instance. @@ -3570,6 +4076,7 @@ def test_transport_instance(): client = PermissionServiceClient(transport=transport) assert client.transport is transport + def test_transport_get_channel(): # A client may be instantiated with a custom transport instance. 
transport = transports.PermissionServiceGrpcTransport( @@ -3584,28 +4091,37 @@ def test_transport_get_channel(): channel = transport.grpc_channel assert channel -@pytest.mark.parametrize("transport_class", [ - transports.PermissionServiceGrpcTransport, - transports.PermissionServiceGrpcAsyncIOTransport, - transports.PermissionServiceRestTransport, -]) + +@pytest.mark.parametrize( + "transport_class", + [ + transports.PermissionServiceGrpcTransport, + transports.PermissionServiceGrpcAsyncIOTransport, + transports.PermissionServiceRestTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(google.auth, 'default') as adc: + with mock.patch.object(google.auth, "default") as adc: adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() -@pytest.mark.parametrize("transport_name", [ - "grpc", - "rest", -]) + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "rest", + ], +) def test_transport_kind(transport_name): transport = PermissionServiceClient.get_transport_class(transport_name)( credentials=ga_credentials.AnonymousCredentials(), ) assert transport.kind == transport_name + def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = PermissionServiceClient( @@ -3616,18 +4132,21 @@ def test_transport_grpc_default(): transports.PermissionServiceGrpcTransport, ) + def test_permission_service_base_transport_error(): # Passing both a credentials object and credentials_file should raise an error with pytest.raises(core_exceptions.DuplicateCredentialArgs): transport = transports.PermissionServiceTransport( credentials=ga_credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_permission_service_base_transport(): # Instantiate the base transport. - with mock.patch('google.ai.generativelanguage_v1beta3.services.permission_service.transports.PermissionServiceTransport.__init__') as Transport: + with mock.patch( + "google.ai.generativelanguage_v1beta3.services.permission_service.transports.PermissionServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.PermissionServiceTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -3636,12 +4155,12 @@ def test_permission_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. 
methods = ( - 'create_permission', - 'get_permission', - 'list_permissions', - 'update_permission', - 'delete_permission', - 'transfer_ownership', + "create_permission", + "get_permission", + "list_permissions", + "update_permission", + "delete_permission", + "transfer_ownership", ) for method in methods: with pytest.raises(NotImplementedError): @@ -3652,7 +4171,7 @@ def test_permission_service_base_transport(): # Catch all for all remaining methods and properties remainder = [ - 'kind', + "kind", ] for r in remainder: with pytest.raises(NotImplementedError): @@ -3661,24 +4180,30 @@ def test_permission_service_base_transport(): def test_permission_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta3.services.permission_service.transports.PermissionServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.ai.generativelanguage_v1beta3.services.permission_service.transports.PermissionServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) transport = transports.PermissionServiceTransport( credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", + load_creds.assert_called_once_with( + "credentials.json", scopes=None, - default_scopes=( -), + default_scopes=(), quota_project_id="octopus", ) def test_permission_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta3.services.permission_service.transports.PermissionServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.ai.generativelanguage_v1beta3.services.permission_service.transports.PermissionServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport = transports.PermissionServiceTransport() @@ -3687,13 +4212,12 @@ def test_permission_service_base_transport_with_adc(): def test_permission_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(google.auth, 'default', autospec=True) as adc: + with mock.patch.object(google.auth, "default", autospec=True) as adc: adc.return_value = (ga_credentials.AnonymousCredentials(), None) PermissionServiceClient() adc.assert_called_once_with( scopes=None, - default_scopes=( -), + default_scopes=(), quota_project_id=None, ) @@ -3708,7 +4232,7 @@ def test_permission_service_auth_adc(): def test_permission_service_transport_auth_adc(transport_class): # If credentials and host are not provided, the transport class should use # ADC credentials. 
- with mock.patch.object(google.auth, 'default', autospec=True) as adc: + with mock.patch.object(google.auth, "default", autospec=True) as adc: adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport_class(quota_project_id="octopus", scopes=["1", "2"]) adc.assert_called_once_with( @@ -3727,47 +4251,45 @@ def test_permission_service_transport_auth_adc(transport_class): ], ) def test_permission_service_transport_auth_gdch_credentials(transport_class): - host = 'https://language.com' - api_audience_tests = [None, 'https://language2.com'] - api_audience_expect = [host, 'https://language2.com'] + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] for t, e in zip(api_audience_tests, api_audience_expect): - with mock.patch.object(google.auth, 'default', autospec=True) as adc: + with mock.patch.object(google.auth, "default", autospec=True) as adc: gdch_mock = mock.MagicMock() - type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock) + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) adc.return_value = (gdch_mock, None) transport_class(host=host, api_audience=t) - gdch_mock.with_gdch_audience.assert_called_once_with( - e - ) + gdch_mock.with_gdch_audience.assert_called_once_with(e) @pytest.mark.parametrize( "transport_class,grpc_helpers", [ (transports.PermissionServiceGrpcTransport, grpc_helpers), - (transports.PermissionServiceGrpcAsyncIOTransport, grpc_helpers_async) + (transports.PermissionServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) def test_permission_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object( + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( grpc_helpers, "create_channel", autospec=True ) as create_channel: creds = ga_credentials.AnonymousCredentials() adc.return_value = (creds, None) - transport_class( - quota_project_id="octopus", - scopes=["1", "2"] - ) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) create_channel.assert_called_with( "generativelanguage.googleapis.com:443", credentials=creds, credentials_file=None, quota_project_id="octopus", - default_scopes=( -), + default_scopes=(), scopes=["1", "2"], default_host="generativelanguage.googleapis.com", ssl_credentials=None, @@ -3778,10 +4300,14 @@ def test_permission_service_transport_create_channel(transport_class, grpc_helpe ) -@pytest.mark.parametrize("transport_class", [transports.PermissionServiceGrpcTransport, transports.PermissionServiceGrpcAsyncIOTransport]) -def test_permission_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.PermissionServiceGrpcTransport, + transports.PermissionServiceGrpcAsyncIOTransport, + ], +) +def test_permission_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = ga_credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. 
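The transport ADC tests in the hunks above all share one pattern: stub google.auth.default to hand back an anonymous credential, build the transport with no explicit credentials, and assert the stub was consulted. A minimal self-contained sketch of that pattern, assuming only google-auth, the standard-library unittest.mock, and the generated transports module (the helper name check_transport_uses_adc is illustrative, not part of the library):

from unittest import mock

import google.auth
from google.auth import credentials as ga_credentials

from google.ai.generativelanguage_v1beta3.services.permission_service import transports


def check_transport_uses_adc(transport_class):
    # Replace google.auth.default so no real credentials are needed.
    with mock.patch.object(google.auth, "default", autospec=True) as adc:
        adc.return_value = (ga_credentials.AnonymousCredentials(), None)
        # A transport built without explicit credentials must fall back
        # to Application Default Credentials exactly once.
        transport_class()
        adc.assert_called_once()


check_transport_uses_adc(transports.PermissionServiceGrpcTransport)

The same stub-and-assert shape also covers the credentials_file and quota_project_id variants exercised above; only the constructor arguments and the expected call to the patched hook change.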
@@ -3790,7 +4316,7 @@ def test_permission_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", @@ -3811,61 +4337,77 @@ def test_permission_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) + def test_permission_service_http_transport_client_cert_source_for_mtls(): cred = ga_credentials.AnonymousCredentials() - with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel: - transports.PermissionServiceRestTransport ( - credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.PermissionServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback ) mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) -@pytest.mark.parametrize("transport_name", [ - "grpc", - "grpc_asyncio", - "rest", -]) +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) def test_permission_service_host_no_port(transport_name): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'), - transport=transport_name, + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com" + ), + transport=transport_name, ) assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else 'https://generativelanguage.googleapis.com' + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" ) -@pytest.mark.parametrize("transport_name", [ - "grpc", - "grpc_asyncio", - "rest", -]) + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) def test_permission_service_host_with_port(transport_name): client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com:8000" + ), transport=transport_name, ) assert client.transport._host == ( - 'generativelanguage.googleapis.com:8000' - if transport_name in ['grpc', 'grpc_asyncio'] - else 'https://generativelanguage.googleapis.com:8000' + "generativelanguage.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com:8000" ) -@pytest.mark.parametrize("transport_name", [ - "rest", -]) + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) def 
test_permission_service_client_transport_session_collision(transport_name): creds1 = ga_credentials.AnonymousCredentials() creds2 = ga_credentials.AnonymousCredentials() @@ -3895,8 +4437,10 @@ def test_permission_service_client_transport_session_collision(transport_name): session1 = client1.transport.transfer_ownership._session session2 = client2.transport.transfer_ownership._session assert session1 != session2 + + def test_permission_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.PermissionServiceGrpcTransport( @@ -3909,7 +4453,7 @@ def test_permission_service_grpc_transport_channel(): def test_permission_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.PermissionServiceGrpcAsyncIOTransport( @@ -3923,12 +4467,22 @@ def test_permission_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.PermissionServiceGrpcTransport, transports.PermissionServiceGrpcAsyncIOTransport]) +@pytest.mark.parametrize( + "transport_class", + [ + transports.PermissionServiceGrpcTransport, + transports.PermissionServiceGrpcAsyncIOTransport, + ], +) def test_permission_service_transport_channel_mtls_with_client_cert_source( - transport_class + transport_class, ): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3937,7 +4491,7 @@ def test_permission_service_transport_channel_mtls_with_client_cert_source( cred = ga_credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(google.auth, 'default') as adc: + with mock.patch.object(google.auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -3967,17 +4521,23 @@ def test_permission_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. 
-@pytest.mark.parametrize("transport_class", [transports.PermissionServiceGrpcTransport, transports.PermissionServiceGrpcAsyncIOTransport]) -def test_permission_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [ + transports.PermissionServiceGrpcTransport, + transports.PermissionServiceGrpcAsyncIOTransport, + ], +) +def test_permission_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -4008,7 +4568,10 @@ def test_permission_service_transport_channel_mtls_with_adc( def test_permission_path(): tuned_model = "squid" permission = "clam" - expected = "tunedModels/{tuned_model}/permissions/{permission}".format(tuned_model=tuned_model, permission=permission, ) + expected = "tunedModels/{tuned_model}/permissions/{permission}".format( + tuned_model=tuned_model, + permission=permission, + ) actual = PermissionServiceClient.permission_path(tuned_model, permission) assert expected == actual @@ -4024,9 +4587,12 @@ def test_parse_permission_path(): actual = PermissionServiceClient.parse_permission_path(path) assert expected == actual + def test_tuned_model_path(): tuned_model = "oyster" - expected = "tunedModels/{tuned_model}".format(tuned_model=tuned_model, ) + expected = "tunedModels/{tuned_model}".format( + tuned_model=tuned_model, + ) actual = PermissionServiceClient.tuned_model_path(tuned_model) assert expected == actual @@ -4041,9 +4607,12 @@ def test_parse_tuned_model_path(): actual = PermissionServiceClient.parse_tuned_model_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "cuttlefish" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = PermissionServiceClient.common_billing_account_path(billing_account) assert expected == actual @@ -4058,9 +4627,12 @@ def test_parse_common_billing_account_path(): actual = PermissionServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "winkle" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format( + folder=folder, + ) actual = PermissionServiceClient.common_folder_path(folder) assert expected == actual @@ -4075,9 +4647,12 @@ def test_parse_common_folder_path(): actual = PermissionServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "scallop" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format( + organization=organization, + ) actual = PermissionServiceClient.common_organization_path(organization) assert expected == actual @@ -4092,9 +4667,12 @@ def test_parse_common_organization_path(): actual = PermissionServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "squid" - expected = 
"projects/{project}".format(project=project, ) + expected = "projects/{project}".format( + project=project, + ) actual = PermissionServiceClient.common_project_path(project) assert expected == actual @@ -4109,10 +4687,14 @@ def test_parse_common_project_path(): actual = PermissionServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "whelk" location = "octopus" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) actual = PermissionServiceClient.common_location_path(project, location) assert expected == actual @@ -4132,14 +4714,18 @@ def test_parse_common_location_path(): def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.PermissionServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.PermissionServiceTransport, "_prep_wrapped_messages" + ) as prep: client = PermissionServiceClient( credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.PermissionServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.PermissionServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = PermissionServiceClient.get_transport_class() transport = transport_class( credentials=ga_credentials.AnonymousCredentials(), @@ -4147,13 +4733,16 @@ def test_client_with_default_client_info(): ) prep.assert_called_once_with(client_info) + @pytest.mark.asyncio async def test_transport_close_async(): client = PermissionServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", ) - with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close: + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: async with client: close.assert_not_called() close.assert_called_once() @@ -4167,23 +4756,24 @@ def test_transport_close(): for transport, close_name in transports.items(): client = PermissionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport + credentials=ga_credentials.AnonymousCredentials(), transport=transport ) - with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close: + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: with client: close.assert_not_called() close.assert_called_once() + def test_client_ctx(): transports = [ - 'rest', - 'grpc', + "rest", + "grpc", ] for transport in transports: client = PermissionServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport + credentials=ga_credentials.AnonymousCredentials(), transport=transport ) # Test client calls underlying transport. 
with mock.patch.object(type(client.transport), "close") as close: @@ -4192,10 +4782,17 @@ def test_client_ctx(): pass close.assert_called() -@pytest.mark.parametrize("client_class,transport_class", [ - (PermissionServiceClient, transports.PermissionServiceGrpcTransport), - (PermissionServiceAsyncClient, transports.PermissionServiceGrpcAsyncIOTransport), -]) + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (PermissionServiceClient, transports.PermissionServiceGrpcTransport), + ( + PermissionServiceAsyncClient, + transports.PermissionServiceGrpcAsyncIOTransport, + ), + ], +) def test_api_key_credentials(client_class, transport_class): with mock.patch.object( google.auth._default, "get_api_key_credentials", create=True diff --git a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_text_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/test_text_service.py similarity index 70% rename from owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_text_service.py rename to packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/test_text_service.py index 045435db1fc0..87b063b0a9f6 100644 --- a/owl-bot-staging/google-ai-generativelanguage/v1beta3/tests/unit/gapic/generativelanguage_v1beta3/test_text_service.py +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta3/test_text_service.py @@ -14,6 +14,7 @@ # limitations under the License. # import os + # try/except added for compatibility with python < 3.8 try: from unittest import mock @@ -21,36 +22,33 @@ except ImportError: # pragma: NO COVER import mock -import grpc -from grpc.experimental import aio from collections.abc import Iterable -from google.protobuf import json_format import json import math -import pytest -from proto.marshal.rules.dates import DurationRule, TimestampRule -from proto.marshal.rules import wrappers -from requests import Response -from requests import Request, PreparedRequest -from requests.sessions import Session -from google.protobuf import json_format -from google.ai.generativelanguage_v1beta3.services.text_service import TextServiceAsyncClient -from google.ai.generativelanguage_v1beta3.services.text_service import TextServiceClient -from google.ai.generativelanguage_v1beta3.services.text_service import transports -from google.ai.generativelanguage_v1beta3.types import safety -from google.ai.generativelanguage_v1beta3.types import text_service +from google.api_core import gapic_v1, grpc_helpers, grpc_helpers_async, path_template from google.api_core import client_options from google.api_core import exceptions as core_exceptions -from google.api_core import gapic_v1 -from google.api_core import grpc_helpers -from google.api_core import grpc_helpers_async -from google.api_core import path_template +import google.auth from google.auth import credentials as ga_credentials from google.auth.exceptions import MutualTLSChannelError -from google.longrunning import operations_pb2 # type: ignore +from google.longrunning import operations_pb2 # type: ignore from google.oauth2 import service_account -import google.auth +from google.protobuf import json_format +import grpc +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +import pytest +from requests import PreparedRequest, Request, Response +from requests.sessions import 
Session + +from google.ai.generativelanguage_v1beta3.services.text_service import ( + TextServiceAsyncClient, + TextServiceClient, + transports, +) +from google.ai.generativelanguage_v1beta3.types import safety, text_service def client_cert_source_callback(): @@ -61,7 +59,11 @@ def client_cert_source_callback(): # This method modifies the default endpoint so the client can produce a different # mtls endpoint for endpoint testing purposes. def modify_default_endpoint(client): - return "foo.googleapis.com" if ("localhost" in client.DEFAULT_ENDPOINT) else client.DEFAULT_ENDPOINT + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) def test__get_default_mtls_endpoint(): @@ -72,21 +74,37 @@ def test__get_default_mtls_endpoint(): non_googleapi = "api.example.com" assert TextServiceClient._get_default_mtls_endpoint(None) is None - assert TextServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint - assert TextServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) == api_mtls_endpoint - assert TextServiceClient._get_default_mtls_endpoint(sandbox_endpoint) == sandbox_mtls_endpoint - assert TextServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) == sandbox_mtls_endpoint + assert ( + TextServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + TextServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + TextServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + TextServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) assert TextServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize("client_class,transport_name", [ - (TextServiceClient, "grpc"), - (TextServiceAsyncClient, "grpc_asyncio"), - (TextServiceClient, "rest"), -]) +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (TextServiceClient, "grpc"), + (TextServiceAsyncClient, "grpc_asyncio"), + (TextServiceClient, "rest"), + ], +) def test_text_service_client_from_service_account_info(client_class, transport_name): creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: factory.return_value = creds info = {"valid": True} client = client_class.from_service_account_info(info, transport=transport_name) @@ -94,52 +112,68 @@ def test_text_service_client_from_service_account_info(client_class, transport_n assert isinstance(client, client_class) assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" ) -@pytest.mark.parametrize("transport_class,transport_name", [ - (transports.TextServiceGrpcTransport, "grpc"), - (transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (transports.TextServiceRestTransport, "rest"), -]) -def test_text_service_client_service_account_always_use_jwt(transport_class, transport_name): - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + 
(transports.TextServiceGrpcTransport, "grpc"), + (transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.TextServiceRestTransport, "rest"), + ], +) +def test_text_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: creds = service_account.Credentials(None, None, None) transport = transport_class(credentials=creds, always_use_jwt_access=True) use_jwt.assert_called_once_with(True) - with mock.patch.object(service_account.Credentials, 'with_always_use_jwt_access', create=True) as use_jwt: + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: creds = service_account.Credentials(None, None, None) transport = transport_class(credentials=creds, always_use_jwt_access=False) use_jwt.assert_not_called() -@pytest.mark.parametrize("client_class,transport_name", [ - (TextServiceClient, "grpc"), - (TextServiceAsyncClient, "grpc_asyncio"), - (TextServiceClient, "rest"), -]) +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (TextServiceClient, "grpc"), + (TextServiceAsyncClient, "grpc_asyncio"), + (TextServiceClient, "rest"), + ], +) def test_text_service_client_from_service_account_file(client_class, transport_name): creds = ga_credentials.AnonymousCredentials() - with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: factory.return_value = creds - client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name) + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) assert client.transport._credentials == creds assert isinstance(client, client_class) - client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name) + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) assert client.transport._credentials == creds assert isinstance(client, client_class) assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else - 'https://generativelanguage.googleapis.com' + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" ) @@ -155,30 +189,43 @@ def test_text_service_client_get_transport_class(): assert transport == transports.TextServiceGrpcTransport -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (TextServiceClient, transports.TextServiceGrpcTransport, "grpc"), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (TextServiceClient, transports.TextServiceRestTransport, "rest"), -]) -@mock.patch.object(TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient)) -@mock.patch.object(TextServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceAsyncClient)) -def test_text_service_client_client_options(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc"), + ( + TextServiceAsyncClient, + transports.TextServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (TextServiceClient, 
transports.TextServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient) +) +@mock.patch.object( + TextServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(TextServiceAsyncClient), +) +def test_text_service_client_client_options( + client_class, transport_class, transport_name +): # Check that if channel is provided we won't create a new one. - with mock.patch.object(TextServiceClient, 'get_transport_class') as gtc: - transport = transport_class( - credentials=ga_credentials.AnonymousCredentials() - ) + with mock.patch.object(TextServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) client = client_class(transport=transport) gtc.assert_not_called() # Check that if channel is provided via str we will create a new one. - with mock.patch.object(TextServiceClient, 'get_transport_class') as gtc: + with mock.patch.object(TextServiceClient, "get_transport_class") as gtc: client = client_class(transport=transport_name) gtc.assert_called() # Check the case api_endpoint is provided. options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(transport=transport_name, client_options=options) patched.assert_called_once_with( @@ -196,7 +243,7 @@ def test_text_service_client_client_options(client_class, transport_class, trans # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "never". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(transport=transport_name) patched.assert_called_once_with( @@ -214,7 +261,7 @@ def test_text_service_client_client_options(client_class, transport_class, trans # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is # "always". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(transport=transport_name) patched.assert_called_once_with( @@ -236,13 +283,15 @@ def test_text_service_client_client_options(client_class, transport_class, trans client = client_class(transport=transport_name) # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"}): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): with pytest.raises(ValueError): client = client_class(transport=transport_name) # Check the case quota_project_id is provided options = client_options.ClientOptions(quota_project_id="octopus") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -257,8 +306,10 @@ def test_text_service_client_client_options(client_class, transport_class, trans api_audience=None, ) # Check the case api_endpoint is provided - options = client_options.ClientOptions(api_audience="https://language.googleapis.com") - with mock.patch.object(transport_class, '__init__') as patched: + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -270,29 +321,55 @@ def test_text_service_client_client_options(client_class, transport_class, trans quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, always_use_jwt_access=True, - api_audience="https://language.googleapis.com" - ) - -@pytest.mark.parametrize("client_class,transport_class,transport_name,use_client_cert_env", [ - (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", "true"), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", "true"), - (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", "false"), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", "false"), - (TextServiceClient, transports.TextServiceRestTransport, "rest", "true"), - (TextServiceClient, transports.TextServiceRestTransport, "rest", "false"), -]) -@mock.patch.object(TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient)) -@mock.patch.object(TextServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceAsyncClient)) + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", "true"), + ( + TextServiceAsyncClient, + transports.TextServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", "false"), + ( + TextServiceAsyncClient, + transports.TextServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + (TextServiceClient, transports.TextServiceRestTransport, "rest", "true"), + (TextServiceClient, transports.TextServiceRestTransport, "rest", "false"), + ], +) +@mock.patch.object( + TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient) +) +@mock.patch.object( + TextServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(TextServiceAsyncClient), +) @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) -def test_text_service_client_mtls_env_auto(client_class, transport_class, transport_name, use_client_cert_env): +def test_text_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): 
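    # Summary of the autoswitch matrix exercised in this test (inferred from
    # the three cases below; GOOGLE_API_USE_MTLS_ENDPOINT is pinned to "auto"
    # by the decorator above):
    #   GOOGLE_API_USE_CLIENT_CERTIFICATE == "false"
    #       -> default endpoint, no client cert, even if a cert source exists.
    #   GOOGLE_API_USE_CLIENT_CERTIFICATE == "true" with an explicit or
    #   ADC-discovered cert source
    #       -> mTLS endpoint plus that cert source.
    #   GOOGLE_API_USE_CLIENT_CERTIFICATE == "true" but no cert source found
    #       -> default endpoint, no client cert.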
# This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. # Check the case client_cert_source is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - options = client_options.ClientOptions(client_cert_source=client_cert_source_callback) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) @@ -317,10 +394,18 @@ def test_text_service_client_mtls_env_auto(client_class, transport_class, transp # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. - with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=client_cert_source_callback): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): if use_client_cert_env == "false": expected_host = client.DEFAULT_ENDPOINT expected_client_cert_source = None @@ -343,9 +428,14 @@ def test_text_service_client_mtls_env_auto(client_class, transport_class, transp ) # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}): - with mock.patch.object(transport_class, '__init__') as patched: - with mock.patch("google.auth.transport.mtls.has_default_client_cert_source", return_value=False): + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): patched.return_value = None client = client_class(transport=transport_name) patched.assert_called_once_with( @@ -361,19 +451,27 @@ def test_text_service_client_mtls_env_auto(client_class, transport_class, transp ) -@pytest.mark.parametrize("client_class", [ - TextServiceClient, TextServiceAsyncClient -]) -@mock.patch.object(TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient)) -@mock.patch.object(TextServiceAsyncClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceAsyncClient)) +@pytest.mark.parametrize("client_class", [TextServiceClient, TextServiceAsyncClient]) +@mock.patch.object( + TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient) +) +@mock.patch.object( + TextServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(TextServiceAsyncClient), +) def test_text_service_client_get_mtls_endpoint_and_cert_source(client_class): mock_client_cert_source = mock.Mock() # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) assert api_endpoint == mock_api_endpoint assert cert_source == mock_client_cert_source @@ -381,8 +479,12 @@ def test_text_service_client_get_mtls_endpoint_and_cert_source(client_class): with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): mock_client_cert_source = mock.Mock() mock_api_endpoint = "foo" - options = client_options.ClientOptions(client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint) - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source(options) + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) assert api_endpoint == mock_api_endpoint assert cert_source is None @@ -400,31 +502,52 @@ def test_text_service_client_get_mtls_endpoint_and_cert_source(client_class): # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. 
with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() assert api_endpoint == client_class.DEFAULT_ENDPOINT assert cert_source is None # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): - with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True): - with mock.patch('google.auth.transport.mtls.default_client_cert_source', return_value=mock_client_cert_source): - api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT assert cert_source == mock_client_cert_source -@pytest.mark.parametrize("client_class,transport_class,transport_name", [ - (TextServiceClient, transports.TextServiceGrpcTransport, "grpc"), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"), - (TextServiceClient, transports.TextServiceRestTransport, "rest"), -]) -def test_text_service_client_client_options_scopes(client_class, transport_class, transport_name): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc"), + ( + TextServiceAsyncClient, + transports.TextServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (TextServiceClient, transports.TextServiceRestTransport, "rest"), + ], +) +def test_text_service_client_client_options_scopes( + client_class, transport_class, transport_name +): # Check the case scopes are provided. 
options = client_options.ClientOptions( scopes=["1", "2"], ) - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -439,18 +562,27 @@ def test_text_service_client_client_options_scopes(client_class, transport_class api_audience=None, ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", grpc_helpers), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), - (TextServiceClient, transports.TextServiceRestTransport, "rest", None), -]) -def test_text_service_client_client_options_credentials_file(client_class, transport_class, transport_name, grpc_helpers): + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", grpc_helpers), + ( + TextServiceAsyncClient, + transports.TextServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + (TextServiceClient, transports.TextServiceRestTransport, "rest", None), + ], +) +def test_text_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) + options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -465,12 +597,13 @@ def test_text_service_client_client_options_credentials_file(client_class, trans api_audience=None, ) + def test_text_service_client_client_options_from_dict(): - with mock.patch('google.ai.generativelanguage_v1beta3.services.text_service.transports.TextServiceGrpcTransport.__init__') as grpc_transport: + with mock.patch( + "google.ai.generativelanguage_v1beta3.services.text_service.transports.TextServiceGrpcTransport.__init__" + ) as grpc_transport: grpc_transport.return_value = None - client = TextServiceClient( - client_options={'api_endpoint': 'squid.clam.whelk'} - ) + client = TextServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) grpc_transport.assert_called_once_with( credentials=None, credentials_file=None, @@ -484,17 +617,25 @@ def test_text_service_client_client_options_from_dict(): ) -@pytest.mark.parametrize("client_class,transport_class,transport_name,grpc_helpers", [ - (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", grpc_helpers), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio", grpc_helpers_async), -]) -def test_text_service_client_create_channel_credentials_file(client_class, transport_class, transport_name, grpc_helpers): +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", grpc_helpers), + ( + TextServiceAsyncClient, + transports.TextServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_text_service_client_create_channel_credentials_file( + client_class, 
transport_class, transport_name, grpc_helpers +): # Check the case credentials file is provided. - options = client_options.ClientOptions( - credentials_file="credentials.json" - ) + options = client_options.ClientOptions(credentials_file="credentials.json") - with mock.patch.object(transport_class, '__init__') as patched: + with mock.patch.object(transport_class, "__init__") as patched: patched.return_value = None client = client_class(client_options=options, transport=transport_name) patched.assert_called_once_with( @@ -527,8 +668,7 @@ def test_text_service_client_create_channel_credentials_file(client_class, trans credentials=file_creds, credentials_file=None, quota_project_id=None, - default_scopes=( -), + default_scopes=(), scopes=None, default_host="generativelanguage.googleapis.com", ssl_credentials=None, @@ -539,11 +679,14 @@ def test_text_service_client_create_channel_credentials_file(client_class, trans ) -@pytest.mark.parametrize("request_type", [ - text_service.GenerateTextRequest, - dict, -]) -def test_generate_text(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + text_service.GenerateTextRequest, + dict, + ], +) +def test_generate_text(request_type, transport: str = "grpc"): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -554,12 +697,9 @@ def test_generate_text(request_type, transport: str = 'grpc'): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = text_service.GenerateTextResponse( - ) + call.return_value = text_service.GenerateTextResponse() response = client.generate_text(request) # Establish that the underlying gRPC stub method was called. @@ -576,20 +716,21 @@ def test_generate_text_empty_call(): # i.e. request == None and no flattened fields passed, work. client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: client.generate_text() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == text_service.GenerateTextRequest() + @pytest.mark.asyncio -async def test_generate_text_async(transport: str = 'grpc_asyncio', request_type=text_service.GenerateTextRequest): +async def test_generate_text_async( + transport: str = "grpc_asyncio", request_type=text_service.GenerateTextRequest +): client = TextServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -600,12 +741,11 @@ async def test_generate_text_async(transport: str = 'grpc_asyncio', request_type request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: # Designate an appropriate return value for the call. 
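Every gRPC test that follows leans on one idiom: patch the callable wrapping the stub method, return a canned response, then inspect call.mock_calls — no channel is ever opened. A self-contained sketch of the idiom with plain unittest.mock; FakeTransport is a stand-in, not the generated transport:

from unittest import mock

class FakeTransport:
    def generate_text(self, request):
        raise RuntimeError("would hit the network")

transport = FakeTransport()
with mock.patch.object(FakeTransport, "generate_text") as call:
    call.return_value = {"candidates": []}   # canned response, no channel needed
    response = transport.generate_text({"model": "models/sample1"})

call.assert_called_once()
_, args, _ = call.mock_calls[0]              # same unpacking as the tests above
assert args[0] == {"model": "models/sample1"}
assert response == {"candidates": []}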
- call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(text_service.GenerateTextResponse( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.GenerateTextResponse() + ) response = await client.generate_text(request) # Establish that the underlying gRPC stub method was called. @@ -631,12 +771,10 @@ def test_generate_text_field_headers(): # a field header. Set these to a non-empty value. request = text_service.GenerateTextRequest() - request.model = 'model_value' + request.model = "model_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: call.return_value = text_service.GenerateTextResponse() client.generate_text(request) @@ -648,9 +786,9 @@ def test_generate_text_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -663,13 +801,13 @@ async def test_generate_text_field_headers_async(): # a field header. Set these to a non-empty value. request = text_service.GenerateTextRequest() - request.model = 'model_value' + request.model = "model_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.GenerateTextResponse()) + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.GenerateTextResponse() + ) await client.generate_text(request) # Establish that the underlying gRPC stub method was called. @@ -680,9 +818,9 @@ async def test_generate_text_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] def test_generate_text_flattened(): @@ -691,16 +829,14 @@ def test_generate_text_flattened(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = text_service.GenerateTextResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.generate_text( - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), temperature=0.1198, candidate_count=1573, max_output_tokens=1865, @@ -713,10 +849,10 @@ def test_generate_text_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].model - mock_val = 'model_value' + mock_val = "model_value" assert arg == mock_val arg = args[0].prompt - mock_val = text_service.TextPrompt(text='text_value') + mock_val = text_service.TextPrompt(text="text_value") assert arg == mock_val assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) arg = args[0].candidate_count @@ -741,8 +877,8 @@ def test_generate_text_flattened_error(): with pytest.raises(ValueError): client.generate_text( text_service.GenerateTextRequest(), - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), temperature=0.1198, candidate_count=1573, max_output_tokens=1865, @@ -750,6 +886,7 @@ def test_generate_text_flattened_error(): top_k=541, ) + @pytest.mark.asyncio async def test_generate_text_flattened_async(): client = TextServiceAsyncClient( @@ -757,18 +894,18 @@ async def test_generate_text_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.generate_text), - '__call__') as call: + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = text_service.GenerateTextResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.GenerateTextResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.GenerateTextResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
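The math.isclose(..., rel_tol=1e-6) assertions on temperature are deliberate: proto "float" fields are single precision, so exact equality against a Python double is brittle. A short demonstration, with a struct round trip standing in for the proto round trip:

import math
import struct

original = 0.1198
as_float32 = struct.unpack("f", struct.pack("f", original))[0]
assert as_float32 != original                      # float32 rounds the double
assert math.isclose(as_float32, original, rel_tol=1e-6)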
response = await client.generate_text( - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), temperature=0.1198, candidate_count=1573, max_output_tokens=1865, @@ -781,10 +918,10 @@ async def test_generate_text_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].model - mock_val = 'model_value' + mock_val = "model_value" assert arg == mock_val arg = args[0].prompt - mock_val = text_service.TextPrompt(text='text_value') + mock_val = text_service.TextPrompt(text="text_value") assert arg == mock_val assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) arg = args[0].candidate_count @@ -798,6 +935,7 @@ async def test_generate_text_flattened_async(): mock_val = 541 assert arg == mock_val + @pytest.mark.asyncio async def test_generate_text_flattened_error_async(): client = TextServiceAsyncClient( @@ -809,8 +947,8 @@ async def test_generate_text_flattened_error_async(): with pytest.raises(ValueError): await client.generate_text( text_service.GenerateTextRequest(), - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), temperature=0.1198, candidate_count=1573, max_output_tokens=1865, @@ -819,11 +957,14 @@ async def test_generate_text_flattened_error_async(): ) -@pytest.mark.parametrize("request_type", [ - text_service.EmbedTextRequest, - dict, -]) -def test_embed_text(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + text_service.EmbedTextRequest, + dict, + ], +) +def test_embed_text(request_type, transport: str = "grpc"): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -834,12 +975,9 @@ def test_embed_text(request_type, transport: str = 'grpc'): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = text_service.EmbedTextResponse( - ) + call.return_value = text_service.EmbedTextResponse() response = client.embed_text(request) # Establish that the underlying gRPC stub method was called. @@ -856,20 +994,21 @@ def test_embed_text_empty_call(): # i.e. request == None and no flattened fields passed, work. client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: client.embed_text() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == text_service.EmbedTextRequest() + @pytest.mark.asyncio -async def test_embed_text_async(transport: str = 'grpc_asyncio', request_type=text_service.EmbedTextRequest): +async def test_embed_text_async( + transport: str = "grpc_asyncio", request_type=text_service.EmbedTextRequest +): client = TextServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -880,12 +1019,11 @@ async def test_embed_text_async(transport: str = 'grpc_asyncio', request_type=te request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(text_service.EmbedTextResponse( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.EmbedTextResponse() + ) response = await client.embed_text(request) # Establish that the underlying gRPC stub method was called. @@ -911,12 +1049,10 @@ def test_embed_text_field_headers(): # a field header. Set these to a non-empty value. request = text_service.EmbedTextRequest() - request.model = 'model_value' + request.model = "model_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: call.return_value = text_service.EmbedTextResponse() client.embed_text(request) @@ -928,9 +1064,9 @@ def test_embed_text_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -943,13 +1079,13 @@ async def test_embed_text_field_headers_async(): # a field header. Set these to a non-empty value. request = text_service.EmbedTextRequest() - request.model = 'model_value' + request.model = "model_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.EmbedTextResponse()) + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.EmbedTextResponse() + ) await client.embed_text(request) # Establish that the underlying gRPC stub method was called. @@ -960,9 +1096,9 @@ async def test_embed_text_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] def test_embed_text_flattened(): @@ -971,16 +1107,14 @@ def test_embed_text_flattened(): ) # Mock the actual call within the gRPC stub, and fake the request. 
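The ("x-goog-request-params", "model=model_value") pair asserted by these field-header tests is the standard routing header: the routed field name and its value, URL-encoded and joined with "&". A hedged sketch of how such a pair is built; the helper is illustrative, since the real client assembles it inside the generated method:

from urllib.parse import quote

def routing_metadata(**routed_fields):
    value = "&".join(
        f"{name}={quote(str(val), safe='')}" for name, val in routed_fields.items()
    )
    return ("x-goog-request-params", value)

assert routing_metadata(model="model_value") == (
    "x-goog-request-params",
    "model=model_value",
)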
- with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = text_service.EmbedTextResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.embed_text( - model='model_value', - text='text_value', + model="model_value", + text="text_value", ) # Establish that the underlying call was made with the expected @@ -988,10 +1122,10 @@ def test_embed_text_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].model - mock_val = 'model_value' + mock_val = "model_value" assert arg == mock_val arg = args[0].text - mock_val = 'text_value' + mock_val = "text_value" assert arg == mock_val @@ -1005,10 +1139,11 @@ def test_embed_text_flattened_error(): with pytest.raises(ValueError): client.embed_text( text_service.EmbedTextRequest(), - model='model_value', - text='text_value', + model="model_value", + text="text_value", ) + @pytest.mark.asyncio async def test_embed_text_flattened_async(): client = TextServiceAsyncClient( @@ -1016,18 +1151,18 @@ async def test_embed_text_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.embed_text), - '__call__') as call: + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = text_service.EmbedTextResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.EmbedTextResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.EmbedTextResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.embed_text( - model='model_value', - text='text_value', + model="model_value", + text="text_value", ) # Establish that the underlying call was made with the expected @@ -1035,12 +1170,13 @@ async def test_embed_text_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].model - mock_val = 'model_value' + mock_val = "model_value" assert arg == mock_val arg = args[0].text - mock_val = 'text_value' + mock_val = "text_value" assert arg == mock_val + @pytest.mark.asyncio async def test_embed_text_flattened_error_async(): client = TextServiceAsyncClient( @@ -1052,16 +1188,19 @@ async def test_embed_text_flattened_error_async(): with pytest.raises(ValueError): await client.embed_text( text_service.EmbedTextRequest(), - model='model_value', - text='text_value', + model="model_value", + text="text_value", ) -@pytest.mark.parametrize("request_type", [ - text_service.BatchEmbedTextRequest, - dict, -]) -def test_batch_embed_text(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + text_service.BatchEmbedTextRequest, + dict, + ], +) +def test_batch_embed_text(request_type, transport: str = "grpc"): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1072,12 +1211,9 @@ def test_batch_embed_text(request_type, transport: str = 'grpc'): request = request_type() # Mock the actual call within the gRPC stub, and fake the request. 
- with mock.patch.object( - type(client.transport.batch_embed_text), - '__call__') as call: + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value = text_service.BatchEmbedTextResponse( - ) + call.return_value = text_service.BatchEmbedTextResponse() response = client.batch_embed_text(request) # Establish that the underlying gRPC stub method was called. @@ -1094,20 +1230,21 @@ def test_batch_embed_text_empty_call(): # i.e. request == None and no flattened fields passed, work. client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.batch_embed_text), - '__call__') as call: + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: client.batch_embed_text() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == text_service.BatchEmbedTextRequest() + @pytest.mark.asyncio -async def test_batch_embed_text_async(transport: str = 'grpc_asyncio', request_type=text_service.BatchEmbedTextRequest): +async def test_batch_embed_text_async( + transport: str = "grpc_asyncio", request_type=text_service.BatchEmbedTextRequest +): client = TextServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1118,12 +1255,11 @@ async def test_batch_embed_text_async(transport: str = 'grpc_asyncio', request_t request = request_type() # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.batch_embed_text), - '__call__') as call: + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(text_service.BatchEmbedTextResponse( - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.BatchEmbedTextResponse() + ) response = await client.batch_embed_text(request) # Establish that the underlying gRPC stub method was called. @@ -1149,12 +1285,10 @@ def test_batch_embed_text_field_headers(): # a field header. Set these to a non-empty value. request = text_service.BatchEmbedTextRequest() - request.model = 'model_value' + request.model = "model_value" # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.batch_embed_text), - '__call__') as call: + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: call.return_value = text_service.BatchEmbedTextResponse() client.batch_embed_text(request) @@ -1166,9 +1300,9 @@ def test_batch_embed_text_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1181,13 +1315,13 @@ async def test_batch_embed_text_field_headers_async(): # a field header. Set these to a non-empty value. request = text_service.BatchEmbedTextRequest() - request.model = 'model_value' + request.model = "model_value" # Mock the actual call within the gRPC stub, and fake the request. 
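The async variants never build a channel either: grpc_helpers_async.FakeUnaryUnaryCall wraps the canned response in an awaitable, so awaiting the client method resolves immediately. A minimal sketch of such an already-finished awaitable; the class below is a stand-in, not the real helper:

import asyncio

class FakeUnaryUnaryCall:
    def __init__(self, response):
        self._response = response

    def __await__(self):
        if False:
            yield                 # generator form that finishes immediately
        return self._response     # becomes the awaited value

async def main():
    assert await FakeUnaryUnaryCall("embedding") == "embedding"

asyncio.run(main())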
- with mock.patch.object( - type(client.transport.batch_embed_text), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.BatchEmbedTextResponse()) + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.BatchEmbedTextResponse() + ) await client.batch_embed_text(request) # Establish that the underlying gRPC stub method was called. @@ -1198,9 +1332,9 @@ async def test_batch_embed_text_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] def test_batch_embed_text_flattened(): @@ -1209,16 +1343,14 @@ def test_batch_embed_text_flattened(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.batch_embed_text), - '__call__') as call: + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = text_service.BatchEmbedTextResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. client.batch_embed_text( - model='model_value', - texts=['texts_value'], + model="model_value", + texts=["texts_value"], ) # Establish that the underlying call was made with the expected @@ -1226,10 +1358,10 @@ def test_batch_embed_text_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].model - mock_val = 'model_value' + mock_val = "model_value" assert arg == mock_val arg = args[0].texts - mock_val = ['texts_value'] + mock_val = ["texts_value"] assert arg == mock_val @@ -1243,10 +1375,11 @@ def test_batch_embed_text_flattened_error(): with pytest.raises(ValueError): client.batch_embed_text( text_service.BatchEmbedTextRequest(), - model='model_value', - texts=['texts_value'], + model="model_value", + texts=["texts_value"], ) + @pytest.mark.asyncio async def test_batch_embed_text_flattened_async(): client = TextServiceAsyncClient( @@ -1254,18 +1387,18 @@ async def test_batch_embed_text_flattened_async(): ) # Mock the actual call within the gRPC stub, and fake the request. - with mock.patch.object( - type(client.transport.batch_embed_text), - '__call__') as call: + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: # Designate an appropriate return value for the call. call.return_value = text_service.BatchEmbedTextResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.BatchEmbedTextResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.BatchEmbedTextResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
response = await client.batch_embed_text( - model='model_value', - texts=['texts_value'], + model="model_value", + texts=["texts_value"], ) # Establish that the underlying call was made with the expected @@ -1273,12 +1406,13 @@ async def test_batch_embed_text_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].model - mock_val = 'model_value' + mock_val = "model_value" assert arg == mock_val arg = args[0].texts - mock_val = ['texts_value'] + mock_val = ["texts_value"] assert arg == mock_val + @pytest.mark.asyncio async def test_batch_embed_text_flattened_error_async(): client = TextServiceAsyncClient( @@ -1290,16 +1424,19 @@ async def test_batch_embed_text_flattened_error_async(): with pytest.raises(ValueError): await client.batch_embed_text( text_service.BatchEmbedTextRequest(), - model='model_value', - texts=['texts_value'], + model="model_value", + texts=["texts_value"], ) -@pytest.mark.parametrize("request_type", [ - text_service.CountTextTokensRequest, - dict, -]) -def test_count_text_tokens(request_type, transport: str = 'grpc'): +@pytest.mark.parametrize( + "request_type", + [ + text_service.CountTextTokensRequest, + dict, + ], +) +def test_count_text_tokens(request_type, transport: str = "grpc"): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1311,8 +1448,8 @@ def test_count_text_tokens(request_type, transport: str = 'grpc'): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.count_text_tokens), - '__call__') as call: + type(client.transport.count_text_tokens), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = text_service.CountTextTokensResponse( token_count=1193, @@ -1334,20 +1471,23 @@ def test_count_text_tokens_empty_call(): # i.e. request == None and no flattened fields passed, work. client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='grpc', + transport="grpc", ) # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.count_text_tokens), - '__call__') as call: + type(client.transport.count_text_tokens), "__call__" + ) as call: client.count_text_tokens() call.assert_called() _, args, _ = call.mock_calls[0] assert args[0] == text_service.CountTextTokensRequest() + @pytest.mark.asyncio -async def test_count_text_tokens_async(transport: str = 'grpc_asyncio', request_type=text_service.CountTextTokensRequest): +async def test_count_text_tokens_async( + transport: str = "grpc_asyncio", request_type=text_service.CountTextTokensRequest +): client = TextServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1359,12 +1499,14 @@ async def test_count_text_tokens_async(transport: str = 'grpc_asyncio', request_ # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.count_text_tokens), - '__call__') as call: + type(client.transport.count_text_tokens), "__call__" + ) as call: # Designate an appropriate return value for the call. - call.return_value =grpc_helpers_async.FakeUnaryUnaryCall(text_service.CountTextTokensResponse( - token_count=1193, - )) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.CountTextTokensResponse( + token_count=1193, + ) + ) response = await client.count_text_tokens(request) # Establish that the underlying gRPC stub method was called. 
@@ -1391,12 +1533,12 @@ def test_count_text_tokens_field_headers(): # a field header. Set these to a non-empty value. request = text_service.CountTextTokensRequest() - request.model = 'model_value' + request.model = "model_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.count_text_tokens), - '__call__') as call: + type(client.transport.count_text_tokens), "__call__" + ) as call: call.return_value = text_service.CountTextTokensResponse() client.count_text_tokens(request) @@ -1408,9 +1550,9 @@ def test_count_text_tokens_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1423,13 +1565,15 @@ async def test_count_text_tokens_field_headers_async(): # a field header. Set these to a non-empty value. request = text_service.CountTextTokensRequest() - request.model = 'model_value' + request.model = "model_value" # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.count_text_tokens), - '__call__') as call: - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.CountTextTokensResponse()) + type(client.transport.count_text_tokens), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.CountTextTokensResponse() + ) await client.count_text_tokens(request) # Establish that the underlying gRPC stub method was called. @@ -1440,9 +1584,9 @@ async def test_count_text_tokens_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - 'x-goog-request-params', - 'model=model_value', - ) in kw['metadata'] + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] def test_count_text_tokens_flattened(): @@ -1452,15 +1596,15 @@ def test_count_text_tokens_flattened(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.count_text_tokens), - '__call__') as call: + type(client.transport.count_text_tokens), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = text_service.CountTextTokensResponse() # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. 
client.count_text_tokens( - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), ) # Establish that the underlying call was made with the expected @@ -1468,10 +1612,10 @@ def test_count_text_tokens_flattened(): assert len(call.mock_calls) == 1 _, args, _ = call.mock_calls[0] arg = args[0].model - mock_val = 'model_value' + mock_val = "model_value" assert arg == mock_val arg = args[0].prompt - mock_val = text_service.TextPrompt(text='text_value') + mock_val = text_service.TextPrompt(text="text_value") assert arg == mock_val @@ -1485,10 +1629,11 @@ def test_count_text_tokens_flattened_error(): with pytest.raises(ValueError): client.count_text_tokens( text_service.CountTextTokensRequest(), - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), ) + @pytest.mark.asyncio async def test_count_text_tokens_flattened_async(): client = TextServiceAsyncClient( @@ -1497,17 +1642,19 @@ async def test_count_text_tokens_flattened_async(): # Mock the actual call within the gRPC stub, and fake the request. with mock.patch.object( - type(client.transport.count_text_tokens), - '__call__') as call: + type(client.transport.count_text_tokens), "__call__" + ) as call: # Designate an appropriate return value for the call. call.return_value = text_service.CountTextTokensResponse() - call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(text_service.CountTextTokensResponse()) + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.CountTextTokensResponse() + ) # Call the method with a truthy value for each flattened field, # using the keyword arguments to the method. response = await client.count_text_tokens( - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), ) # Establish that the underlying call was made with the expected @@ -1515,12 +1662,13 @@ async def test_count_text_tokens_flattened_async(): assert len(call.mock_calls) _, args, _ = call.mock_calls[0] arg = args[0].model - mock_val = 'model_value' + mock_val = "model_value" assert arg == mock_val arg = args[0].prompt - mock_val = text_service.TextPrompt(text='text_value') + mock_val = text_service.TextPrompt(text="text_value") assert arg == mock_val + @pytest.mark.asyncio async def test_count_text_tokens_flattened_error_async(): client = TextServiceAsyncClient( @@ -1532,15 +1680,18 @@ async def test_count_text_tokens_flattened_error_async(): with pytest.raises(ValueError): await client.count_text_tokens( text_service.CountTextTokensRequest(), - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), ) -@pytest.mark.parametrize("request_type", [ - text_service.GenerateTextRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + text_service.GenerateTextRequest, + dict, + ], +) def test_generate_text_rest(request_type): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -1548,14 +1699,13 @@ def test_generate_text_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} + request_init = {"model": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. 
- with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. - return_value = text_service.GenerateTextResponse( - ) + return_value = text_service.GenerateTextResponse() # Wrap the value into a proper Response obj response_value = Response() @@ -1563,7 +1713,7 @@ def test_generate_text_rest(request_type): pb_return_value = text_service.GenerateTextResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.generate_text(request) @@ -1571,58 +1721,66 @@ def test_generate_text_rest(request_type): assert isinstance(response, text_service.GenerateTextResponse) -def test_generate_text_rest_required_fields(request_type=text_service.GenerateTextRequest): +def test_generate_text_rest_required_fields( + request_type=text_service.GenerateTextRequest, +): transport_class = transports.TextServiceRestTransport request_init = {} request_init["model"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_text._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).generate_text._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["model"] = 'model_value' + jsonified_request["model"] = "model_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).generate_text._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).generate_text._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "model" in jsonified_request - assert jsonified_request["model"] == 'model_value' + assert jsonified_request["model"] == "model_value" client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = text_service.GenerateTextResponse() # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. 
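The REST tests fabricate success responses the same way throughout: serialize the expected proto to JSON, stuff the bytes into requests.Response._content, and hand that back from the patched session. A minimal sketch of that fabrication, assuming only the requests package:

from requests import Response

def fake_json_response(payload: str, status: int = 200) -> Response:
    response = Response()
    response.status_code = status
    response._content = payload.encode("UTF-8")   # same trick as the tests
    return response

resp = fake_json_response('{"candidates": []}')
assert resp.status_code == 200
assert resp.json() == {"candidates": []}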
pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "post", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, } - transcode_result['body'] = pb_request + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() @@ -1631,39 +1789,56 @@ def test_generate_text_rest_required_fields(request_type=text_service.GenerateTe pb_return_value = text_service.GenerateTextResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.generate_text(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_generate_text_rest_unset_required_fields(): - transport = transports.TextServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.TextServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.generate_text._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("model", "prompt", ))) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "prompt", + ) + ) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_generate_text_rest_interceptors(null_interceptor): transport = transports.TextServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.TextServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.TextServiceRestInterceptor(), + ) client = TextServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.TextServiceRestInterceptor, "post_generate_text") as post, \ - mock.patch.object(transports.TextServiceRestInterceptor, "pre_generate_text") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.TextServiceRestInterceptor, "post_generate_text" + ) as post, mock.patch.object( + transports.TextServiceRestInterceptor, "pre_generate_text" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = text_service.GenerateTextRequest.pb(text_service.GenerateTextRequest()) + pb_message = text_service.GenerateTextRequest.pb( + text_service.GenerateTextRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -1674,34 +1849,46 @@ def test_generate_text_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = text_service.GenerateTextResponse.to_json(text_service.GenerateTextResponse()) + req.return_value._content = text_service.GenerateTextResponse.to_json( + text_service.GenerateTextResponse() + ) request = text_service.GenerateTextRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = 
text_service.GenerateTextResponse() - client.generate_text(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.generate_text( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_generate_text_rest_bad_request(transport: str = 'rest', request_type=text_service.GenerateTextRequest): +def test_generate_text_rest_bad_request( + transport: str = "rest", request_type=text_service.GenerateTextRequest +): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} + request_init = {"model": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -1717,17 +1904,17 @@ def test_generate_text_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = text_service.GenerateTextResponse() # get arguments that satisfy an http rule for this method - sample_request = {'model': 'models/sample1'} + sample_request = {"model": "models/sample1"} # get truthy value for each flattened field mock_args = dict( - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), temperature=0.1198, candidate_count=1573, max_output_tokens=1865, @@ -1741,7 +1928,7 @@ def test_generate_text_rest_flattened(): response_value.status_code = 200 pb_return_value = text_service.GenerateTextResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.generate_text(**mock_args) @@ -1750,10 +1937,12 @@ def test_generate_text_rest_flattened(): # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{model=models/*}:generateText" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{model=models/*}:generateText" % client.transport._host, args[1] + ) -def test_generate_text_rest_flattened_error(transport: str = 'rest'): +def test_generate_text_rest_flattened_error(transport: str = "rest"): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -1764,8 +1953,8 @@ def test_generate_text_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.generate_text( text_service.GenerateTextRequest(), - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), temperature=0.1198, candidate_count=1573, max_output_tokens=1865, @@ -1776,15 +1965,17 @@ def test_generate_text_rest_flattened_error(transport: str = 'rest'): def test_generate_text_rest_error(): client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) -@pytest.mark.parametrize("request_type", [ - text_service.EmbedTextRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + text_service.EmbedTextRequest, + dict, + ], +) def test_embed_text_rest(request_type): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -1792,14 +1983,13 @@ def test_embed_text_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} + request_init = {"model": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = text_service.EmbedTextResponse( - ) + return_value = text_service.EmbedTextResponse() # Wrap the value into a proper Response obj response_value = Response() @@ -1807,7 +1997,7 @@ def test_embed_text_rest(request_type): pb_return_value = text_service.EmbedTextResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.embed_text(request) @@ -1823,54 +2013,60 @@ def test_embed_text_rest_required_fields(request_type=text_service.EmbedTextRequ request_init["text"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).embed_text._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).embed_text._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["model"] = 'model_value' - jsonified_request["text"] = 'text_value' + jsonified_request["model"] = "model_value" + jsonified_request["text"] = "text_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).embed_text._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).embed_text._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "model" in jsonified_request - assert jsonified_request["model"] == 'model_value' + assert jsonified_request["model"] == "model_value" assert "text" in jsonified_request - assert jsonified_request["text"] == 'text_value' + assert jsonified_request["text"] == "text_value" client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = text_service.EmbedTextResponse() # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. 
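transcode() is what maps a request message onto an HTTP call: a URI built from the path template, a verb, and the leftover fields routed to the body or query params. A hedged sketch of what that yields for the embedText rule; transcode_embed_text is illustrative, not the path_template API:

def transcode_embed_text(request: dict) -> dict:
    model = request["model"]                      # must satisfy {model=models/*}
    assert model.startswith("models/")
    return {
        "method": "post",
        "uri": f"/v1beta3/{model}:embedText",
        "body": {k: v for k, v in request.items() if k != "model"},
    }

result = transcode_embed_text({"model": "models/sample1", "text": "text_value"})
assert result["uri"] == "/v1beta3/models/sample1:embedText"
assert result["body"] == {"text": "text_value"}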
pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "post", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, } - transcode_result['body'] = pb_request + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() @@ -1879,36 +2075,51 @@ def test_embed_text_rest_required_fields(request_type=text_service.EmbedTextRequ pb_return_value = text_service.EmbedTextResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.embed_text(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_embed_text_rest_unset_required_fields(): - transport = transports.TextServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.TextServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.embed_text._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("model", "text", ))) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "text", + ) + ) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_embed_text_rest_interceptors(null_interceptor): transport = transports.TextServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.TextServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.TextServiceRestInterceptor(), + ) client = TextServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.TextServiceRestInterceptor, "post_embed_text") as post, \ - mock.patch.object(transports.TextServiceRestInterceptor, "pre_embed_text") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.TextServiceRestInterceptor, "post_embed_text" + ) as post, mock.patch.object( + transports.TextServiceRestInterceptor, "pre_embed_text" + ) as pre: pre.assert_not_called() post.assert_not_called() pb_message = text_service.EmbedTextRequest.pb(text_service.EmbedTextRequest()) @@ -1922,34 +2133,46 @@ def test_embed_text_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = text_service.EmbedTextResponse.to_json(text_service.EmbedTextResponse()) + req.return_value._content = text_service.EmbedTextResponse.to_json( + text_service.EmbedTextResponse() + ) request = text_service.EmbedTextRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = text_service.EmbedTextResponse() - client.embed_text(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.embed_text( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) 
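What the interceptor tests pin down is a simple contract: pre_* may rewrite (request, metadata) before the HTTP call, post_* may rewrite the response after it, and each runs exactly once. A self-contained sketch of that flow; the names mirror the generated interceptor but the classes here are stand-ins:

class SketchInterceptor:
    def pre_embed_text(self, request, metadata):
        return request, metadata + [("traced", "yes")]   # may rewrite both

    def post_embed_text(self, response):
        return response                                  # may rewrite the response

def invoke(interceptor, request, metadata, send):
    request, metadata = interceptor.pre_embed_text(request, metadata)
    return interceptor.post_embed_text(send(request, metadata))

resp = invoke(
    SketchInterceptor(),
    {"model": "models/sample1"},
    [("key", "val"), ("cephalopod", "squid")],
    send=lambda req, md: {"n_metadata": len(md)},
)
assert resp == {"n_metadata": 3}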
pre.assert_called_once() post.assert_called_once() -def test_embed_text_rest_bad_request(transport: str = 'rest', request_type=text_service.EmbedTextRequest): +def test_embed_text_rest_bad_request( + transport: str = "rest", request_type=text_service.EmbedTextRequest +): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} + request_init = {"model": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -1965,17 +2188,17 @@ def test_embed_text_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = text_service.EmbedTextResponse() # get arguments that satisfy an http rule for this method - sample_request = {'model': 'models/sample1'} + sample_request = {"model": "models/sample1"} # get truthy value for each flattened field mock_args = dict( - model='model_value', - text='text_value', + model="model_value", + text="text_value", ) mock_args.update(sample_request) @@ -1984,7 +2207,7 @@ def test_embed_text_rest_flattened(): response_value.status_code = 200 pb_return_value = text_service.EmbedTextResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.embed_text(**mock_args) @@ -1993,10 +2216,12 @@ def test_embed_text_rest_flattened(): # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{model=models/*}:embedText" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{model=models/*}:embedText" % client.transport._host, args[1] + ) -def test_embed_text_rest_flattened_error(transport: str = 'rest'): +def test_embed_text_rest_flattened_error(transport: str = "rest"): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2007,22 +2232,24 @@ def test_embed_text_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.embed_text( text_service.EmbedTextRequest(), - model='model_value', - text='text_value', + model="model_value", + text="text_value", ) def test_embed_text_rest_error(): client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) -@pytest.mark.parametrize("request_type", [ - text_service.BatchEmbedTextRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + text_service.BatchEmbedTextRequest, + dict, + ], +) def test_batch_embed_text_rest(request_type): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -2030,14 +2257,13 @@ def test_batch_embed_text_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} + request_init = {"model": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
- return_value = text_service.BatchEmbedTextResponse( - ) + return_value = text_service.BatchEmbedTextResponse() # Wrap the value into a proper Response obj response_value = Response() @@ -2045,7 +2271,7 @@ def test_batch_embed_text_rest(request_type): pb_return_value = text_service.BatchEmbedTextResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.batch_embed_text(request) @@ -2053,7 +2279,9 @@ def test_batch_embed_text_rest(request_type): assert isinstance(response, text_service.BatchEmbedTextResponse) -def test_batch_embed_text_rest_required_fields(request_type=text_service.BatchEmbedTextRequest): +def test_batch_embed_text_rest_required_fields( + request_type=text_service.BatchEmbedTextRequest, +): transport_class = transports.TextServiceRestTransport request_init = {} @@ -2061,54 +2289,60 @@ def test_batch_embed_text_rest_required_fields(request_type=text_service.BatchEm request_init["texts"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).batch_embed_text._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_embed_text._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["model"] = 'model_value' - jsonified_request["texts"] = 'texts_value' + jsonified_request["model"] = "model_value" + jsonified_request["texts"] = "texts_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).batch_embed_text._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_embed_text._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "model" in jsonified_request - assert jsonified_request["model"] == 'model_value' + assert jsonified_request["model"] == "model_value" assert "texts" in jsonified_request - assert jsonified_request["texts"] == 'texts_value' + assert jsonified_request["texts"] == "texts_value" client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = text_service.BatchEmbedTextResponse() # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. 
- with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "post", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, } - transcode_result['body'] = pb_request + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() @@ -2117,39 +2351,56 @@ def test_batch_embed_text_rest_required_fields(request_type=text_service.BatchEm pb_return_value = text_service.BatchEmbedTextResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.batch_embed_text(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_batch_embed_text_rest_unset_required_fields(): - transport = transports.TextServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.TextServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.batch_embed_text._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("model", "texts", ))) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "texts", + ) + ) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_batch_embed_text_rest_interceptors(null_interceptor): transport = transports.TextServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.TextServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.TextServiceRestInterceptor(), + ) client = TextServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.TextServiceRestInterceptor, "post_batch_embed_text") as post, \ - mock.patch.object(transports.TextServiceRestInterceptor, "pre_batch_embed_text") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.TextServiceRestInterceptor, "post_batch_embed_text" + ) as post, mock.patch.object( + transports.TextServiceRestInterceptor, "pre_batch_embed_text" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = text_service.BatchEmbedTextRequest.pb(text_service.BatchEmbedTextRequest()) + pb_message = text_service.BatchEmbedTextRequest.pb( + text_service.BatchEmbedTextRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -2160,34 +2411,46 @@ def test_batch_embed_text_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = text_service.BatchEmbedTextResponse.to_json(text_service.BatchEmbedTextResponse()) + 
req.return_value._content = text_service.BatchEmbedTextResponse.to_json( + text_service.BatchEmbedTextResponse() + ) request = text_service.BatchEmbedTextRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] pre.return_value = request, metadata post.return_value = text_service.BatchEmbedTextResponse() - client.batch_embed_text(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.batch_embed_text( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_batch_embed_text_rest_bad_request(transport: str = 'rest', request_type=text_service.BatchEmbedTextRequest): +def test_batch_embed_text_rest_bad_request( + transport: str = "rest", request_type=text_service.BatchEmbedTextRequest +): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} + request_init = {"model": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -2203,17 +2466,17 @@ def test_batch_embed_text_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = text_service.BatchEmbedTextResponse() # get arguments that satisfy an http rule for this method - sample_request = {'model': 'models/sample1'} + sample_request = {"model": "models/sample1"} # get truthy value for each flattened field mock_args = dict( - model='model_value', - texts=['texts_value'], + model="model_value", + texts=["texts_value"], ) mock_args.update(sample_request) @@ -2222,7 +2485,7 @@ def test_batch_embed_text_rest_flattened(): response_value.status_code = 200 pb_return_value = text_service.BatchEmbedTextResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.batch_embed_text(**mock_args) @@ -2231,10 +2494,13 @@ def test_batch_embed_text_rest_flattened(): # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{model=models/*}:batchEmbedText" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{model=models/*}:batchEmbedText" % client.transport._host, + args[1], + ) -def test_batch_embed_text_rest_flattened_error(transport: str = 'rest'): +def test_batch_embed_text_rest_flattened_error(transport: str = "rest"): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2245,22 +2511,24 @@ def test_batch_embed_text_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.batch_embed_text( text_service.BatchEmbedTextRequest(), - model='model_value', - texts=['texts_value'], + model="model_value", + texts=["texts_value"], ) def test_batch_embed_text_rest_error(): client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) -@pytest.mark.parametrize("request_type", [ - text_service.CountTextTokensRequest, - dict, -]) +@pytest.mark.parametrize( + "request_type", + [ + text_service.CountTextTokensRequest, + dict, + ], +) def test_count_text_tokens_rest(request_type): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), @@ -2268,14 +2536,14 @@ def test_count_text_tokens_rest(request_type): ) # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} + request_init = {"model": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. 
return_value = text_service.CountTextTokensResponse( - token_count=1193, + token_count=1193, ) # Wrap the value into a proper Response obj @@ -2284,7 +2552,7 @@ def test_count_text_tokens_rest(request_type): pb_return_value = text_service.CountTextTokensResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.count_text_tokens(request) @@ -2293,58 +2561,66 @@ def test_count_text_tokens_rest(request_type): assert response.token_count == 1193 -def test_count_text_tokens_rest_required_fields(request_type=text_service.CountTextTokensRequest): +def test_count_text_tokens_rest_required_fields( + request_type=text_service.CountTextTokensRequest, +): transport_class = transports.TextServiceRestTransport request_init = {} request_init["model"] = "" request = request_type(**request_init) pb_request = request_type.pb(request) - jsonified_request = json.loads(json_format.MessageToJson( - pb_request, - including_default_value_fields=False, - use_integers_for_enums=False - )) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) # verify fields with default values are dropped - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).count_text_tokens._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).count_text_tokens._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with default values are now present - jsonified_request["model"] = 'model_value' + jsonified_request["model"] = "model_value" - unset_fields = transport_class(credentials=ga_credentials.AnonymousCredentials()).count_text_tokens._get_unset_required_fields(jsonified_request) + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).count_text_tokens._get_unset_required_fields(jsonified_request) jsonified_request.update(unset_fields) # verify required fields with non-default values are left alone assert "model" in jsonified_request - assert jsonified_request["model"] == 'model_value' + assert jsonified_request["model"] == "model_value" client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), - transport='rest', + transport="rest", ) request = request_type(**request_init) # Designate an appropriate value for the returned response. return_value = text_service.CountTextTokensResponse() # Mock the http request call within the method and fake a response. - with mock.patch.object(Session, 'request') as req: + with mock.patch.object(Session, "request") as req: # We need to mock transcode() because providing default values # for required fields will fail the real version if the http_options # expect actual values for those fields. - with mock.patch.object(path_template, 'transcode') as transcode: + with mock.patch.object(path_template, "transcode") as transcode: # A uri without fields and an empty body will force all the # request fields to show up in the query_params. 
pb_request = request_type.pb(request) transcode_result = { - 'uri': 'v1/sample_method', - 'method': "post", - 'query_params': pb_request, + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, } - transcode_result['body'] = pb_request + transcode_result["body"] = pb_request transcode.return_value = transcode_result response_value = Response() @@ -2353,39 +2629,56 @@ def test_count_text_tokens_rest_required_fields(request_type=text_service.CountT pb_return_value = text_service.CountTextTokensResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.count_text_tokens(request) - expected_params = [ - ('$alt', 'json;enum-encoding=int') - ] - actual_params = req.call_args.kwargs['params'] + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] assert expected_params == actual_params def test_count_text_tokens_rest_unset_required_fields(): - transport = transports.TextServiceRestTransport(credentials=ga_credentials.AnonymousCredentials) + transport = transports.TextServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) unset_fields = transport.count_text_tokens._get_unset_required_fields({}) - assert set(unset_fields) == (set(()) & set(("model", "prompt", ))) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "prompt", + ) + ) + ) @pytest.mark.parametrize("null_interceptor", [True, False]) def test_count_text_tokens_rest_interceptors(null_interceptor): transport = transports.TextServiceRestTransport( credentials=ga_credentials.AnonymousCredentials(), - interceptor=None if null_interceptor else transports.TextServiceRestInterceptor(), - ) + interceptor=None + if null_interceptor + else transports.TextServiceRestInterceptor(), + ) client = TextServiceClient(transport=transport) - with mock.patch.object(type(client.transport._session), "request") as req, \ - mock.patch.object(path_template, "transcode") as transcode, \ - mock.patch.object(transports.TextServiceRestInterceptor, "post_count_text_tokens") as post, \ - mock.patch.object(transports.TextServiceRestInterceptor, "pre_count_text_tokens") as pre: + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.TextServiceRestInterceptor, "post_count_text_tokens" + ) as post, mock.patch.object( + transports.TextServiceRestInterceptor, "pre_count_text_tokens" + ) as pre: pre.assert_not_called() post.assert_not_called() - pb_message = text_service.CountTextTokensRequest.pb(text_service.CountTextTokensRequest()) + pb_message = text_service.CountTextTokensRequest.pb( + text_service.CountTextTokensRequest() + ) transcode.return_value = { "method": "post", "uri": "my_uri", @@ -2396,34 +2689,46 @@ def test_count_text_tokens_rest_interceptors(null_interceptor): req.return_value = Response() req.return_value.status_code = 200 req.return_value.request = PreparedRequest() - req.return_value._content = text_service.CountTextTokensResponse.to_json(text_service.CountTextTokensResponse()) + req.return_value._content = text_service.CountTextTokensResponse.to_json( + text_service.CountTextTokensResponse() + ) request = text_service.CountTextTokensRequest() - metadata =[ + metadata = [ ("key", "val"), ("cephalopod", "squid"), ] 
pre.return_value = request, metadata post.return_value = text_service.CountTextTokensResponse() - client.count_text_tokens(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + client.count_text_tokens( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) pre.assert_called_once() post.assert_called_once() -def test_count_text_tokens_rest_bad_request(transport: str = 'rest', request_type=text_service.CountTextTokensRequest): +def test_count_text_tokens_rest_bad_request( + transport: str = "rest", request_type=text_service.CountTextTokensRequest +): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, ) # send a request that will satisfy transcoding - request_init = {'model': 'models/sample1'} + request_init = {"model": "models/sample1"} request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. - with mock.patch.object(Session, 'request') as req, pytest.raises(core_exceptions.BadRequest): + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): # Wrap the value into a proper Response obj response_value = Response() response_value.status_code = 400 @@ -2439,17 +2744,17 @@ def test_count_text_tokens_rest_flattened(): ) # Mock the http request call within the method and fake a response. - with mock.patch.object(type(client.transport._session), 'request') as req: + with mock.patch.object(type(client.transport._session), "request") as req: # Designate an appropriate value for the returned response. return_value = text_service.CountTextTokensResponse() # get arguments that satisfy an http rule for this method - sample_request = {'model': 'models/sample1'} + sample_request = {"model": "models/sample1"} # get truthy value for each flattened field mock_args = dict( - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), ) mock_args.update(sample_request) @@ -2458,7 +2763,7 @@ def test_count_text_tokens_rest_flattened(): response_value.status_code = 200 pb_return_value = text_service.CountTextTokensResponse.pb(return_value) json_return_value = json_format.MessageToJson(pb_return_value) - response_value._content = json_return_value.encode('UTF-8') + response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value client.count_text_tokens(**mock_args) @@ -2467,10 +2772,13 @@ def test_count_text_tokens_rest_flattened(): # request object values. 
assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] - assert path_template.validate("%s/v1beta3/{model=models/*}:countTextTokens" % client.transport._host, args[1]) + assert path_template.validate( + "%s/v1beta3/{model=models/*}:countTextTokens" % client.transport._host, + args[1], + ) -def test_count_text_tokens_rest_flattened_error(transport: str = 'rest'): +def test_count_text_tokens_rest_flattened_error(transport: str = "rest"): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -2481,15 +2789,14 @@ def test_count_text_tokens_rest_flattened_error(transport: str = 'rest'): with pytest.raises(ValueError): client.count_text_tokens( text_service.CountTextTokensRequest(), - model='model_value', - prompt=text_service.TextPrompt(text='text_value'), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), ) def test_count_text_tokens_rest_error(): client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport='rest' + credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) @@ -2531,8 +2838,7 @@ def test_credentials_transport_error(): options.api_key = "api_key" with pytest.raises(ValueError): client = TextServiceClient( - client_options=options, - credentials=ga_credentials.AnonymousCredentials() + client_options=options, credentials=ga_credentials.AnonymousCredentials() ) # It is an error to provide scopes and a transport instance. @@ -2554,6 +2860,7 @@ def test_transport_instance(): client = TextServiceClient(transport=transport) assert client.transport is transport + def test_transport_get_channel(): # A client may be instantiated with a custom transport instance. transport = transports.TextServiceGrpcTransport( @@ -2568,28 +2875,37 @@ def test_transport_get_channel(): channel = transport.grpc_channel assert channel -@pytest.mark.parametrize("transport_class", [ - transports.TextServiceGrpcTransport, - transports.TextServiceGrpcAsyncIOTransport, - transports.TextServiceRestTransport, -]) + +@pytest.mark.parametrize( + "transport_class", + [ + transports.TextServiceGrpcTransport, + transports.TextServiceGrpcAsyncIOTransport, + transports.TextServiceRestTransport, + ], +) def test_transport_adc(transport_class): # Test default credentials are used if not provided. - with mock.patch.object(google.auth, 'default') as adc: + with mock.patch.object(google.auth, "default") as adc: adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport_class() adc.assert_called_once() -@pytest.mark.parametrize("transport_name", [ - "grpc", - "rest", -]) + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "rest", + ], +) def test_transport_kind(transport_name): transport = TextServiceClient.get_transport_class(transport_name)( credentials=ga_credentials.AnonymousCredentials(), ) assert transport.kind == transport_name + def test_transport_grpc_default(): # A client should use the gRPC transport by default. client = TextServiceClient( @@ -2600,18 +2916,21 @@ def test_transport_grpc_default(): transports.TextServiceGrpcTransport, ) + def test_text_service_base_transport_error(): # Passing both a credentials object and credentials_file should raise an error with pytest.raises(core_exceptions.DuplicateCredentialArgs): transport = transports.TextServiceTransport( credentials=ga_credentials.AnonymousCredentials(), - credentials_file="credentials.json" + credentials_file="credentials.json", ) def test_text_service_base_transport(): # Instantiate the base transport. 
- with mock.patch('google.ai.generativelanguage_v1beta3.services.text_service.transports.TextServiceTransport.__init__') as Transport: + with mock.patch( + "google.ai.generativelanguage_v1beta3.services.text_service.transports.TextServiceTransport.__init__" + ) as Transport: Transport.return_value = None transport = transports.TextServiceTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -2620,10 +2939,10 @@ def test_text_service_base_transport(): # Every method on the transport should just blindly # raise NotImplementedError. methods = ( - 'generate_text', - 'embed_text', - 'batch_embed_text', - 'count_text_tokens', + "generate_text", + "embed_text", + "batch_embed_text", + "count_text_tokens", ) for method in methods: with pytest.raises(NotImplementedError): @@ -2634,7 +2953,7 @@ def test_text_service_base_transport(): # Catch all for all remaining methods and properties remainder = [ - 'kind', + "kind", ] for r in remainder: with pytest.raises(NotImplementedError): @@ -2643,24 +2962,30 @@ def test_text_service_base_transport(): def test_text_service_base_transport_with_credentials_file(): # Instantiate the base transport with a credentials file - with mock.patch.object(google.auth, 'load_credentials_from_file', autospec=True) as load_creds, mock.patch('google.ai.generativelanguage_v1beta3.services.text_service.transports.TextServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.ai.generativelanguage_v1beta3.services.text_service.transports.TextServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) transport = transports.TextServiceTransport( credentials_file="credentials.json", quota_project_id="octopus", ) - load_creds.assert_called_once_with("credentials.json", + load_creds.assert_called_once_with( + "credentials.json", scopes=None, - default_scopes=( -), + default_scopes=(), quota_project_id="octopus", ) def test_text_service_base_transport_with_adc(): # Test the default credentials are used if credentials and credentials_file are None. - with mock.patch.object(google.auth, 'default', autospec=True) as adc, mock.patch('google.ai.generativelanguage_v1beta3.services.text_service.transports.TextServiceTransport._prep_wrapped_messages') as Transport: + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.ai.generativelanguage_v1beta3.services.text_service.transports.TextServiceTransport._prep_wrapped_messages" + ) as Transport: Transport.return_value = None adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport = transports.TextServiceTransport() @@ -2669,13 +2994,12 @@ def test_text_service_base_transport_with_adc(): def test_text_service_auth_adc(): # If no credentials are provided, we should use ADC credentials. - with mock.patch.object(google.auth, 'default', autospec=True) as adc: + with mock.patch.object(google.auth, "default", autospec=True) as adc: adc.return_value = (ga_credentials.AnonymousCredentials(), None) TextServiceClient() adc.assert_called_once_with( scopes=None, - default_scopes=( -), + default_scopes=(), quota_project_id=None, ) @@ -2690,7 +3014,7 @@ def test_text_service_auth_adc(): def test_text_service_transport_auth_adc(transport_class): # If credentials and host are not provided, the transport class should use # ADC credentials. 
- with mock.patch.object(google.auth, 'default', autospec=True) as adc: + with mock.patch.object(google.auth, "default", autospec=True) as adc: adc.return_value = (ga_credentials.AnonymousCredentials(), None) transport_class(quota_project_id="octopus", scopes=["1", "2"]) adc.assert_called_once_with( @@ -2709,47 +3033,45 @@ def test_text_service_transport_auth_adc(transport_class): ], ) def test_text_service_transport_auth_gdch_credentials(transport_class): - host = 'https://language.com' - api_audience_tests = [None, 'https://language2.com'] - api_audience_expect = [host, 'https://language2.com'] + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] for t, e in zip(api_audience_tests, api_audience_expect): - with mock.patch.object(google.auth, 'default', autospec=True) as adc: + with mock.patch.object(google.auth, "default", autospec=True) as adc: gdch_mock = mock.MagicMock() - type(gdch_mock).with_gdch_audience = mock.PropertyMock(return_value=gdch_mock) + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) adc.return_value = (gdch_mock, None) transport_class(host=host, api_audience=t) - gdch_mock.with_gdch_audience.assert_called_once_with( - e - ) + gdch_mock.with_gdch_audience.assert_called_once_with(e) @pytest.mark.parametrize( "transport_class,grpc_helpers", [ (transports.TextServiceGrpcTransport, grpc_helpers), - (transports.TextServiceGrpcAsyncIOTransport, grpc_helpers_async) + (transports.TextServiceGrpcAsyncIOTransport, grpc_helpers_async), ], ) def test_text_service_transport_create_channel(transport_class, grpc_helpers): # If credentials and host are not provided, the transport class should use # ADC credentials. - with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch.object( + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( grpc_helpers, "create_channel", autospec=True ) as create_channel: creds = ga_credentials.AnonymousCredentials() adc.return_value = (creds, None) - transport_class( - quota_project_id="octopus", - scopes=["1", "2"] - ) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) create_channel.assert_called_with( "generativelanguage.googleapis.com:443", credentials=creds, credentials_file=None, quota_project_id="octopus", - default_scopes=( -), + default_scopes=(), scopes=["1", "2"], default_host="generativelanguage.googleapis.com", ssl_credentials=None, @@ -2760,10 +3082,11 @@ def test_text_service_transport_create_channel(transport_class, grpc_helpers): ) -@pytest.mark.parametrize("transport_class", [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport]) -def test_text_service_grpc_transport_client_cert_source_for_mtls( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport], +) +def test_text_service_grpc_transport_client_cert_source_for_mtls(transport_class): cred = ga_credentials.AnonymousCredentials() # Check ssl_channel_credentials is used if provided. 
@@ -2772,7 +3095,7 @@ def test_text_service_grpc_transport_client_cert_source_for_mtls( transport_class( host="squid.clam.whelk", credentials=cred, - ssl_channel_credentials=mock_ssl_channel_creds + ssl_channel_credentials=mock_ssl_channel_creds, ) mock_create_channel.assert_called_once_with( "squid.clam.whelk:443", @@ -2793,61 +3116,77 @@ def test_text_service_grpc_transport_client_cert_source_for_mtls( with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: transport_class( credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + client_cert_source_for_mtls=client_cert_source_callback, ) expected_cert, expected_key = client_cert_source_callback() mock_ssl_cred.assert_called_once_with( - certificate_chain=expected_cert, - private_key=expected_key + certificate_chain=expected_cert, private_key=expected_key ) + def test_text_service_http_transport_client_cert_source_for_mtls(): cred = ga_credentials.AnonymousCredentials() - with mock.patch("google.auth.transport.requests.AuthorizedSession.configure_mtls_channel") as mock_configure_mtls_channel: - transports.TextServiceRestTransport ( - credentials=cred, - client_cert_source_for_mtls=client_cert_source_callback + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.TextServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback ) mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) -@pytest.mark.parametrize("transport_name", [ - "grpc", - "grpc_asyncio", - "rest", -]) +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) def test_text_service_host_no_port(transport_name): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com'), - transport=transport_name, + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com" + ), + transport=transport_name, ) assert client.transport._host == ( - 'generativelanguage.googleapis.com:443' - if transport_name in ['grpc', 'grpc_asyncio'] - else 'https://generativelanguage.googleapis.com' + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" ) -@pytest.mark.parametrize("transport_name", [ - "grpc", - "grpc_asyncio", - "rest", -]) + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) def test_text_service_host_with_port(transport_name): client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), - client_options=client_options.ClientOptions(api_endpoint='generativelanguage.googleapis.com:8000'), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com:8000" + ), transport=transport_name, ) assert client.transport._host == ( - 'generativelanguage.googleapis.com:8000' - if transport_name in ['grpc', 'grpc_asyncio'] - else 'https://generativelanguage.googleapis.com:8000' + "generativelanguage.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com:8000" ) -@pytest.mark.parametrize("transport_name", [ - "rest", -]) + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) def test_text_service_client_transport_session_collision(transport_name): creds1 = 
ga_credentials.AnonymousCredentials() creds2 = ga_credentials.AnonymousCredentials() @@ -2871,8 +3210,10 @@ def test_text_service_client_transport_session_collision(transport_name): session1 = client1.transport.count_text_tokens._session session2 = client2.transport.count_text_tokens._session assert session1 != session2 + + def test_text_service_grpc_transport_channel(): - channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.TextServiceGrpcTransport( @@ -2885,7 +3226,7 @@ def test_text_service_grpc_transport_channel(): def test_text_service_grpc_asyncio_transport_channel(): - channel = aio.secure_channel('http://localhost/', grpc.local_channel_credentials()) + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.TextServiceGrpcAsyncIOTransport( @@ -2899,12 +3240,17 @@ def test_text_service_grpc_asyncio_transport_channel(): # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. -@pytest.mark.parametrize("transport_class", [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport]) -def test_text_service_transport_channel_mtls_with_client_cert_source( - transport_class -): - with mock.patch("grpc.ssl_channel_credentials", autospec=True) as grpc_ssl_channel_cred: - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: +@pytest.mark.parametrize( + "transport_class", + [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport], +) +def test_text_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2913,7 +3259,7 @@ def test_text_service_transport_channel_mtls_with_client_cert_source( cred = ga_credentials.AnonymousCredentials() with pytest.warns(DeprecationWarning): - with mock.patch.object(google.auth, 'default') as adc: + with mock.patch.object(google.auth, "default") as adc: adc.return_value = (cred, None) transport = transport_class( host="squid.clam.whelk", @@ -2943,17 +3289,20 @@ def test_text_service_transport_channel_mtls_with_client_cert_source( # Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are # removed from grpc/grpc_asyncio transport constructor. 
-@pytest.mark.parametrize("transport_class", [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport]) -def test_text_service_transport_channel_mtls_with_adc( - transport_class -): +@pytest.mark.parametrize( + "transport_class", + [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport], +) +def test_text_service_transport_channel_mtls_with_adc(transport_class): mock_ssl_cred = mock.Mock() with mock.patch.multiple( "google.auth.transport.grpc.SslCredentials", __init__=mock.Mock(return_value=None), ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): - with mock.patch.object(transport_class, "create_channel") as grpc_create_channel: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel mock_cred = mock.Mock() @@ -2983,7 +3332,9 @@ def test_text_service_transport_channel_mtls_with_adc( def test_model_path(): model = "squid" - expected = "models/{model}".format(model=model, ) + expected = "models/{model}".format( + model=model, + ) actual = TextServiceClient.model_path(model) assert expected == actual @@ -2998,9 +3349,12 @@ def test_parse_model_path(): actual = TextServiceClient.parse_model_path(path) assert expected == actual + def test_common_billing_account_path(): billing_account = "whelk" - expected = "billingAccounts/{billing_account}".format(billing_account=billing_account, ) + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) actual = TextServiceClient.common_billing_account_path(billing_account) assert expected == actual @@ -3015,9 +3369,12 @@ def test_parse_common_billing_account_path(): actual = TextServiceClient.parse_common_billing_account_path(path) assert expected == actual + def test_common_folder_path(): folder = "oyster" - expected = "folders/{folder}".format(folder=folder, ) + expected = "folders/{folder}".format( + folder=folder, + ) actual = TextServiceClient.common_folder_path(folder) assert expected == actual @@ -3032,9 +3389,12 @@ def test_parse_common_folder_path(): actual = TextServiceClient.parse_common_folder_path(path) assert expected == actual + def test_common_organization_path(): organization = "cuttlefish" - expected = "organizations/{organization}".format(organization=organization, ) + expected = "organizations/{organization}".format( + organization=organization, + ) actual = TextServiceClient.common_organization_path(organization) assert expected == actual @@ -3049,9 +3409,12 @@ def test_parse_common_organization_path(): actual = TextServiceClient.parse_common_organization_path(path) assert expected == actual + def test_common_project_path(): project = "winkle" - expected = "projects/{project}".format(project=project, ) + expected = "projects/{project}".format( + project=project, + ) actual = TextServiceClient.common_project_path(project) assert expected == actual @@ -3066,10 +3429,14 @@ def test_parse_common_project_path(): actual = TextServiceClient.parse_common_project_path(path) assert expected == actual + def test_common_location_path(): project = "scallop" location = "abalone" - expected = "projects/{project}/locations/{location}".format(project=project, location=location, ) + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) actual = TextServiceClient.common_location_path(project, location) assert expected == actual @@ -3089,14 +3456,18 @@ def 
test_parse_common_location_path(): def test_client_with_default_client_info(): client_info = gapic_v1.client_info.ClientInfo() - with mock.patch.object(transports.TextServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.TextServiceTransport, "_prep_wrapped_messages" + ) as prep: client = TextServiceClient( credentials=ga_credentials.AnonymousCredentials(), client_info=client_info, ) prep.assert_called_once_with(client_info) - with mock.patch.object(transports.TextServiceTransport, '_prep_wrapped_messages') as prep: + with mock.patch.object( + transports.TextServiceTransport, "_prep_wrapped_messages" + ) as prep: transport_class = TextServiceClient.get_transport_class() transport = transport_class( credentials=ga_credentials.AnonymousCredentials(), @@ -3104,13 +3475,16 @@ def test_client_with_default_client_info(): ) prep.assert_called_once_with(client_info) + @pytest.mark.asyncio async def test_transport_close_async(): client = TextServiceAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio", ) - with mock.patch.object(type(getattr(client.transport, "grpc_channel")), "close") as close: + with mock.patch.object( + type(getattr(client.transport, "grpc_channel")), "close" + ) as close: async with client: close.assert_not_called() close.assert_called_once() @@ -3124,23 +3498,24 @@ def test_transport_close(): for transport, close_name in transports.items(): client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport + credentials=ga_credentials.AnonymousCredentials(), transport=transport ) - with mock.patch.object(type(getattr(client.transport, close_name)), "close") as close: + with mock.patch.object( + type(getattr(client.transport, close_name)), "close" + ) as close: with client: close.assert_not_called() close.assert_called_once() + def test_client_ctx(): transports = [ - 'rest', - 'grpc', + "rest", + "grpc", ] for transport in transports: client = TextServiceClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport + credentials=ga_credentials.AnonymousCredentials(), transport=transport ) # Test client calls underlying transport. 
with mock.patch.object(type(client.transport), "close") as close: @@ -3149,10 +3524,14 @@ def test_client_ctx(): pass close.assert_called() -@pytest.mark.parametrize("client_class,transport_class", [ - (TextServiceClient, transports.TextServiceGrpcTransport), - (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport), -]) + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (TextServiceClient, transports.TextServiceGrpcTransport), + (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport), + ], +) def test_api_key_credentials(client_class, transport_class): with mock.patch.object( google.auth._default, "get_api_key_credentials", create=True From 4ac03586b552448eda5d62e2109fd9dd18ef83d9 Mon Sep 17 00:00:00 2001 From: Anthonios Partheniou Date: Wed, 20 Sep 2023 00:38:10 +0000 Subject: [PATCH 3/3] remove whitespace --- .../permission_service/async_client.py | 18 +++++++++--------- .../services/permission_service/client.py | 18 +++++++++--------- .../types/permission.py | 6 +++--- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py index 9b9faceba44c..87274cad5339 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/async_client.py @@ -305,11 +305,11 @@ async def sample_create_permission(): role is a superset of the previous role's permitted operations: - - reader can use the resource (e.g. + - reader can use the resource (e.g. tuned model) for inference - - writer has reader's permissions and + - writer has reader's permissions and additionally can edit and share - - owner has writer's permissions and + - owner has writer's permissions and additionally can delete """ @@ -432,11 +432,11 @@ async def sample_get_permission(): role is a superset of the previous role's permitted operations: - - reader can use the resource (e.g. + - reader can use the resource (e.g. tuned model) for inference - - writer has reader's permissions and + - writer has reader's permissions and additionally can edit and share - - owner has writer's permissions and + - owner has writer's permissions and additionally can delete """ @@ -682,11 +682,11 @@ async def sample_update_permission(): role is a superset of the previous role's permitted operations: - - reader can use the resource (e.g. + - reader can use the resource (e.g. 
tuned model) for inference - - writer has reader's permissions and + - writer has reader's permissions and additionally can edit and share - - owner has writer's permissions and + - owner has writer's permissions and additionally can delete """ diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/client.py index 78bbe681b0cc..9afdb7375e5e 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/services/permission_service/client.py @@ -542,11 +542,11 @@ def sample_create_permission(): role is a superset of the previous role's permitted operations: - - reader can use the resource (e.g. + - reader can use the resource (e.g. tuned model) for inference - - writer has reader's permissions and + - writer has reader's permissions and additionally can edit and share - - owner has writer's permissions and + - owner has writer's permissions and additionally can delete """ @@ -669,11 +669,11 @@ def sample_get_permission(): role is a superset of the previous role's permitted operations: - - reader can use the resource (e.g. + - reader can use the resource (e.g. tuned model) for inference - - writer has reader's permissions and + - writer has reader's permissions and additionally can edit and share - - owner has writer's permissions and + - owner has writer's permissions and additionally can delete """ @@ -919,11 +919,11 @@ def sample_update_permission(): role is a superset of the previous role's permitted operations: - - reader can use the resource (e.g. + - reader can use the resource (e.g. tuned model) for inference - - writer has reader's permissions and + - writer has reader's permissions and additionally can edit and share - - owner has writer's permissions and + - owner has writer's permissions and additionally can delete """ diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/permission.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/permission.py index 115ca22e8bef..09af2311c4ed 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/permission.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/types/permission.py @@ -41,10 +41,10 @@ class Permission(proto.Message): There are three concentric roles. Each role is a superset of the previous role's permitted operations: - - reader can use the resource (e.g. tuned model) for inference - - writer has reader's permissions and additionally can edit and + - reader can use the resource (e.g. tuned model) for inference + - writer has reader's permissions and additionally can edit and share - - owner has writer's permissions and additionally can delete + - owner has writer's permissions and additionally can delete .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
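
The reformatted test hunks above repeat one mocking pattern for every REST method: patch the transport session's HTTP request, hand back a canned Response whose body is the JSON serialization of the expected proto, and let the client deserialize it. The sketch below pulls that pattern together into a single runnable unit, mirroring test_count_text_tokens_rest from this patch. It assumes the google-ai-generativelanguage v1beta3 package introduced here is installed; it is an illustration of the test strategy, not part of the generated code.

from unittest import mock

from google.auth import credentials as ga_credentials
from google.protobuf import json_format
from requests import Response

from google.ai.generativelanguage_v1beta3.services.text_service import (
    TextServiceClient,
)
from google.ai.generativelanguage_v1beta3.types import text_service

client = TextServiceClient(
    credentials=ga_credentials.AnonymousCredentials(), transport="rest"
)

# Canned proto response, serialized the same way the generated tests do it:
# proto-plus message -> raw pb -> JSON -> UTF-8 bytes on the Response body.
return_value = text_service.CountTextTokensResponse(token_count=1193)
response_value = Response()
response_value.status_code = 200
pb_return_value = text_service.CountTextTokensResponse.pb(return_value)
response_value._content = json_format.MessageToJson(pb_return_value).encode("UTF-8")

# Patch the session class the REST transport actually uses, so transcoding
# runs for real but no network call is made.
with mock.patch.object(type(client.transport._session), "request") as req:
    req.return_value = response_value
    response = client.count_text_tokens(
        text_service.CountTextTokensRequest(model="models/sample1")
    )

assert response.token_count == 1193

Because the class attribute is patched rather than a single bound method, the intercepted call is also inspectable afterwards via req.mock_calls, which is how the flattened-field tests in this patch verify the transcoded URI (for example "%s/v1beta3/{model=models/*}:countTextTokens") without ever leaving the process.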