11{% extends "_base.py.j2" %}
22
33{% block content %}
4+ import os
45from unittest import mock
56
67import grpc
@@ -11,6 +12,7 @@ import pytest
1112{% filter sort_lines -%}
1213from google import auth
1314from google.auth import credentials
15+ from google.auth.exceptions import MutualTLSChannelError
1416from google.oauth2 import service_account
1517from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.client_name }}
1618from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import transports
@@ -63,6 +65,14 @@ def test_{{ service.client_name|snake_case }}_from_service_account_file():
6365 {% if service .host %} assert client._transport._host == '{{ service.host }}{% if ":" not in service .host %} :443{% endif %} '{% endif %}
6466
6567
68+ def test_{{ service.client_name|snake_case }}_get_transport_class():
69+ transport = {{ service.client_name }}.get_transport_class()
70+ assert transport == transports.{{ service.name }}GrpcTransport
71+
72+ transport = {{ service.client_name }}.get_transport_class("grpc")
73+ assert transport == transports.{{ service.name }}GrpcTransport
74+
75+
6676def test_{{ service.client_name|snake_case }}_client_options():
6777 # Check that if channel is provided we won't create a new one.
6878 with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
@@ -72,58 +82,99 @@ def test_{{ service.client_name|snake_case }}_client_options():
7282 client = {{ service.client_name }}(transport=transport)
7383 gtc.assert_not_called()
7484
75- # Check mTLS is not triggered with empty client options.
76- options = client_options.ClientOptions()
85+ # Check that if channel is provided via str we will create a new one.
7786 with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.{{ service.client_name }}.get_transport_class') as gtc:
78- transport = gtc.return_value = mock.MagicMock()
79- client = {{ service.client_name }}(client_options=options)
80- transport.assert_called_once_with(
81- credentials=None,
82- host=client.DEFAULT_ENDPOINT,
83- )
87+ client = {{ service.client_name }}(transport="grpc")
88+ gtc.assert_called()
8489
85- # Check mTLS is not triggered if api_endpoint is provided but
86- # client_cert_source is None.
90+ # Check the case api_endpoint is provided.
8791 options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
8892 with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
8993 grpc_transport.return_value = None
9094 client = {{ service.client_name }}(client_options=options)
9195 grpc_transport.assert_called_once_with(
92- api_mtls_endpoint=None ,
96+ api_mtls_endpoint="squid.clam.whelk" ,
9397 client_cert_source=None,
9498 credentials=None,
9599 host="squid.clam.whelk",
96100 )
97101
98- # Check mTLS is triggered if client_cert_source is provided.
99- options = client_options.ClientOptions(
100- client_cert_source=client_cert_source_callback
101- )
102+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
103+ # "Never".
104+ os.environ["GOOGLE_API_USE_MTLS"] = "Never"
102105 with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
103106 grpc_transport.return_value = None
104- client = {{ service.client_name }}(client_options=options )
107+ client = {{ service.client_name }}()
105108 grpc_transport.assert_called_once_with(
106- api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT ,
107- client_cert_source=client_cert_source_callback ,
109+ api_mtls_endpoint=client.DEFAULT_ENDPOINT ,
110+ client_cert_source=None ,
108111 credentials=None,
109112 host=client.DEFAULT_ENDPOINT,
110113 )
111114
112- # Check mTLS is triggered if api_endpoint and client_cert_source are provided.
113- options = client_options.ClientOptions(
114- api_endpoint="squid.clam.whelk",
115- client_cert_source=client_cert_source_callback
116- )
115+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS is
116+ # "Always".
117+ os.environ["GOOGLE_API_USE_MTLS"] = "Always"
118+ with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
119+ grpc_transport.return_value = None
120+ client = {{ service.client_name }}()
121+ grpc_transport.assert_called_once_with(
122+ api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
123+ client_cert_source=None,
124+ credentials=None,
125+ host=client.DEFAULT_MTLS_ENDPOINT,
126+ )
127+
128+ # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
129+ # "Auto", and client_cert_source is provided.
130+ os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
131+ options = client_options.ClientOptions(client_cert_source=client_cert_source_callback)
117132 with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
118133 grpc_transport.return_value = None
119134 client = {{ service.client_name }}(client_options=options)
120135 grpc_transport.assert_called_once_with(
121- api_mtls_endpoint="squid.clam.whelk" ,
136+ api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT ,
122137 client_cert_source=client_cert_source_callback,
123138 credentials=None,
124- host="squid.clam.whelk" ,
139+ host=client.DEFAULT_MTLS_ENDPOINT ,
125140 )
126141
142+ # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
143+ # "Auto", and default_client_cert_source is provided.
144+ os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
145+ with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
146+ with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=True):
147+ grpc_transport.return_value = None
148+ client = {{ service.client_name }}()
149+ grpc_transport.assert_called_once_with(
150+ api_mtls_endpoint=client.DEFAULT_MTLS_ENDPOINT,
151+ client_cert_source=None,
152+ credentials=None,
153+ host=client.DEFAULT_MTLS_ENDPOINT,
154+ )
155+
156+ # Check the case api_endpoint is not provided, GOOGLE_API_USE_MTLS is
157+ # "Auto", but client_cert_source and default_client_cert_source are None.
158+ os.environ["GOOGLE_API_USE_MTLS"] = "Auto"
159+ with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
160+ with mock.patch('google.auth.transport.mtls.has_default_client_cert_source', return_value=False):
161+ grpc_transport.return_value = None
162+ client = {{ service.client_name }}()
163+ grpc_transport.assert_called_once_with(
164+ api_mtls_endpoint=client.DEFAULT_ENDPOINT,
165+ client_cert_source=None,
166+ credentials=None,
167+ host=client.DEFAULT_ENDPOINT,
168+ )
169+
170+ # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS has
171+ # unsupported value.
172+ os.environ["GOOGLE_API_USE_MTLS"] = "Unsupported"
173+ with pytest.raises(MutualTLSChannelError):
174+ client = {{ service.client_name }}()
175+
176+ del os.environ["GOOGLE_API_USE_MTLS"]
177+
127178
128179def test_{{ service.client_name|snake_case }}_client_options_from_dict():
129180 with mock.patch('{{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }}.transports.{{ service.name }}GrpcTransport.__init__') as grpc_transport:
@@ -132,7 +183,7 @@ def test_{{ service.client_name|snake_case }}_client_options_from_dict():
132183 client_options={'api_endpoint': 'squid.clam.whelk'}
133184 )
134185 grpc_transport.assert_called_once_with(
135- api_mtls_endpoint=None ,
186+ api_mtls_endpoint="squid.clam.whelk" ,
136187 client_cert_source=None,
137188 credentials=None,
138189 host="squid.clam.whelk",
@@ -490,12 +541,24 @@ def test_{{ service.name|snake_case }}_auth_adc():
490541 ))
491542
492543
544+ def test_{{ service.name|snake_case }}_transport_auth_adc():
545+ # If credentials and host are not provided, the transport class should use
546+ # ADC credentials.
547+ with mock.patch.object(auth, 'default') as adc:
548+ adc.return_value = (credentials.AnonymousCredentials(), None)
549+ transports.{{ service.name }}GrpcTransport(host="squid.clam.whelk")
550+ adc.assert_called_once_with(scopes=(
551+ {% - for scope in service .oauth_scopes %}
552+ '{{ scope }}',
553+ {% - endfor %}
554+ ))
555+
556+
493557def test_{{ service.name|snake_case }}_host_no_port():
494558 {% with host = (service .host |default ('localhost' , true )).split (':' )[0] -%}
495559 client = {{ service.client_name }}(
496560 credentials=credentials.AnonymousCredentials(),
497561 client_options=client_options.ClientOptions(api_endpoint='{{ host }}'),
498- transport='grpc',
499562 )
500563 assert client._transport._host == '{{ host }}:443'
501564 {% endwith %}
@@ -506,7 +569,6 @@ def test_{{ service.name|snake_case }}_host_with_port():
506569 client = {{ service.client_name }}(
507570 credentials=credentials.AnonymousCredentials(),
508571 client_options=client_options.ClientOptions(api_endpoint='{{ host }}:8000'),
509- transport='grpc',
510572 )
511573 assert client._transport._host == '{{ host }}:8000'
512574 {% endwith %}
0 commit comments