From 6558b6f65791633a186909ddf7bd23b088cf7b91 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Thu, 10 Dec 2020 13:59:06 -0500 Subject: [PATCH] feat: add 'from_service_account_info' factory to clients Closes #705 --- .../%version/%sub/services/%service/client.py.j2 | 16 ++++++++++++++++ .../%name_%version/%sub/test_%service.py.j2 | 11 +++++++++++ .../%sub/services/%service/async_client.py.j2 | 1 + .../%sub/services/%service/client.py.j2 | 16 ++++++++++++++++ .../%name_%version/%sub/test_%service.py.j2 | 11 +++++++++++ 5 files changed, 55 insertions(+) diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 index 5f12de323c..e3d0016a7a 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/client.py.j2 @@ -97,6 +97,22 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): DEFAULT_ENDPOINT ) + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + {@api.name}: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): """Creates an instance of this client using the provided credentials diff --git a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 5dc0f26bab..e86a2f41db 100644 --- a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -61,6 +61,17 @@ def test__get_default_mtls_endpoint(): assert {{ service.client_name }}._get_default_mtls_endpoint(non_googleapi) == non_googleapi +def test_{{ service.client_name|snake_case }}_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + factory.return_value = creds + info = {"valid": True} + client = {{ service.client_name }}.from_service_account_info(info) + assert client.transport._credentials == creds + + {% if service.host %}assert client.transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'{% endif %} + + def test_{{ service.client_name|snake_case }}_from_service_account_file(): creds = credentials.AnonymousCredentials() with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory: diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 index c6320144ef..6354418f7f 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/async_client.py.j2 @@ -48,6 +48,7 @@ class {{ service.async_client_name }}: parse_common_{{ resource_msg.message_type.resource_type|snake_case }}_path = staticmethod({{ service.client_name }}.parse_common_{{ resource_msg.message_type.resource_type|snake_case }}_path) {% endfor %} + from_service_account_info = {{ service.client_name }}.from_service_account_info from_service_account_file = {{ service.client_name }}.from_service_account_file from_service_account_json = from_service_account_file diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 index 6ec3d5d879..6d703fb2de 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2 @@ -113,6 +113,22 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta): DEFAULT_ENDPOINT ) + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + {@api.name}: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): """Creates an instance of this client using the provided credentials diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 5882e95f78..429ccec2d0 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -70,6 +70,17 @@ def test__get_default_mtls_endpoint(): assert {{ service.client_name }}._get_default_mtls_endpoint(non_googleapi) == non_googleapi +def test_{{ service.client_name|snake_case }}_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory: + factory.return_value = creds + info = {"valid": True} + client = {{ service.client_name }}.from_service_account_info(info) + assert client.transport._credentials == creds + + {% if service.host %}assert client.transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'{% endif %} + + @pytest.mark.parametrize("client_class", [{{ service.client_name }}, {{ service.async_client_name }}]) def test_{{ service.client_name|snake_case }}_from_service_account_file(client_class): creds = credentials.AnonymousCredentials()