From a5abe12fc8a0402b724f49c4c34f8d68c9179e72 Mon Sep 17 00:00:00 2001 From: arithmetic1728 Date: Tue, 20 Oct 2020 16:05:16 -0700 Subject: [PATCH] fix: expose ssl credentials from transport --- .../%version/%sub/services/%service/transports/grpc.py.j2 | 4 ++++ .../tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 | 2 ++ .../%sub/services/%service/transports/grpc.py.j2 | 4 ++++ .../%sub/services/%service/transports/grpc_asyncio.py.j2 | 4 ++++ .../tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 | 3 +++ 5 files changed, 17 insertions(+) diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/grpc.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/grpc.py.j2 index 6995549378..6438991b57 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/grpc.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/grpc.py.j2 @@ -88,6 +88,8 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport creation failed for any reason. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -95,6 +97,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) @@ -122,6 +125,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 14687174cd..5dc0f26bab 100644 --- a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -708,6 +708,7 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel(): ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None @pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }}]) @@ -749,6 +750,7 @@ def test_{{ service.name|snake_case }}_transport_channel_mtls_with_client_cert_s quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred @pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }},]) diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 index 47eaffeb19..e2c68c483d 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2 @@ -96,6 +96,8 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -103,6 +105,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) @@ -130,6 +133,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport): scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 index aae858bf79..6399f1f1cd 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2 @@ -140,6 +140,8 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport): google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` and ``credentials_file`` are passed. """ + self._ssl_channel_credentials = ssl_channel_credentials + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -147,6 +149,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport): # If a channel was explicitly provided, set it. self._grpc_channel = channel + self._ssl_channel_credentials = None elif api_mtls_endpoint: warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning) @@ -174,6 +177,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport): scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, ) + self._ssl_channel_credentials = ssl_credentials else: host = host if ":" in host else host + ":443" diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index dd4fd637ec..5882e95f78 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -1184,6 +1184,7 @@ def test_{{ service.name|snake_case }}_grpc_transport_channel(): ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel(): @@ -1196,6 +1197,7 @@ def test_{{ service.name|snake_case }}_grpc_asyncio_transport_channel(): ) assert transport.grpc_channel == channel assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None @pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }}, transports.{{ service.grpc_asyncio_transport_name }}]) @@ -1237,6 +1239,7 @@ def test_{{ service.name|snake_case }}_transport_channel_mtls_with_client_cert_s quota_project_id=None, ) assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred @pytest.mark.parametrize("transport_class", [transports.{{ service.grpc_transport_name }}, transports.{{ service.grpc_asyncio_transport_name }}])