Skip to content

Commit 524dbab

Browse files
feat: add mtls feature to rest transport (#731)
* feat: add mtls support to rest transport * update * update * update
1 parent e17c551 commit 524dbab

File tree

5 files changed

+157
-79
lines changed

5 files changed

+157
-79
lines changed

gapic/templates/%namespace/%name_%version/%sub/services/%service/client.py.j2

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -239,21 +239,15 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
239239
# Create SSL credentials for mutual TLS if needed.
240240
use_client_cert = bool(util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")))
241241

242-
ssl_credentials = None
242+
client_cert_source_func = None
243243
is_mtls = False
244244
if use_client_cert:
245245
if client_options.client_cert_source:
246-
import grpc # type: ignore
247-
248-
cert, key = client_options.client_cert_source()
249-
ssl_credentials = grpc.ssl_channel_credentials(
250-
certificate_chain=cert, private_key=key
251-
)
252246
is_mtls = True
247+
client_cert_source_func = client_options.client_cert_source
253248
else:
254-
creds = SslCredentials()
255-
is_mtls = creds.is_mtls
256-
ssl_credentials = creds.ssl_credentials if is_mtls else None
249+
is_mtls = mtls.has_default_client_cert_source()
250+
client_cert_source_func = mtls.default_client_cert_source() if is_mtls else None
257251

258252
# Figure out which api endpoint to use.
259253
if client_options.api_endpoint is not None:
@@ -292,7 +286,7 @@ class {{ service.client_name }}(metaclass={{ service.client_name }}Meta):
292286
credentials_file=client_options.credentials_file,
293287
host=api_endpoint,
294288
scopes=client_options.scopes,
295-
ssl_channel_credentials=ssl_credentials,
289+
client_cert_source_for_mtls=client_cert_source_func,
296290
quota_project_id=client_options.quota_project_id,
297291
client_info=client_info,
298292
)

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc.py.j2

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
5151
api_mtls_endpoint: str = None,
5252
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
5353
ssl_channel_credentials: grpc.ChannelCredentials = None,
54+
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
5455
quota_project_id: Optional[str] = None,
5556
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
5657
) -> None:
@@ -82,6 +83,10 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
8283
``api_mtls_endpoint`` is None.
8384
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
8485
for grpc channel. It is ignored if ``channel`` is provided.
86+
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
87+
A callback to provide client certificate bytes and private key bytes,
88+
both in PEM format. It is used to configure mutual TLS channel. It is
89+
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
8590
quota_project_id (Optional[str]): An optional project to use for billing
8691
and quota.
8792
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
@@ -98,6 +103,11 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
98103
"""
99104
self._ssl_channel_credentials = ssl_channel_credentials
100105

106+
if api_mtls_endpoint:
107+
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
108+
if client_cert_source:
109+
warnings.warn("client_cert_source is deprecated", DeprecationWarning)
110+
101111
if channel:
102112
# Sanity check: Ensure that channel and credentials are not both
103113
# provided.
@@ -107,8 +117,6 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
107117
self._grpc_channel = channel
108118
self._ssl_channel_credentials = None
109119
elif api_mtls_endpoint:
110-
warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
111-
112120
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
113121

114122
if credentials is None:
@@ -144,12 +152,18 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
144152
if credentials is None:
145153
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
146154

155+
if client_cert_source_for_mtls and not ssl_channel_credentials:
156+
cert, key = client_cert_source_for_mtls()
157+
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
158+
certificate_chain=cert, private_key=key
159+
)
160+
147161
# create a new channel. The provided one is ignored.
148162
self._grpc_channel = type(self).create_channel(
149163
host,
150164
credentials=credentials,
151165
credentials_file=credentials_file,
152-
ssl_credentials=ssl_channel_credentials,
166+
ssl_credentials=self._ssl_channel_credentials,
153167
scopes=scopes or self.AUTH_SCOPES,
154168
quota_project_id=quota_project_id,
155169
options=[

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/grpc_asyncio.py.j2

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
9494
api_mtls_endpoint: str = None,
9595
client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
9696
ssl_channel_credentials: grpc.ChannelCredentials = None,
97+
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
9798
quota_project_id=None,
9899
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
99100
) -> None:
@@ -126,6 +127,10 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
126127
``api_mtls_endpoint`` is None.
127128
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
128129
for grpc channel. It is ignored if ``channel`` is provided.
130+
client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
131+
A callback to provide client certificate bytes and private key bytes,
132+
both in PEM format. It is used to configure mutual TLS channel. It is
133+
ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
129134
quota_project_id (Optional[str]): An optional project to use for billing
130135
and quota.
131136
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
@@ -141,6 +146,11 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
141146
and ``credentials_file`` are passed.
142147
"""
143148
self._ssl_channel_credentials = ssl_channel_credentials
149+
150+
if api_mtls_endpoint:
151+
warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
152+
if client_cert_source:
153+
warnings.warn("client_cert_source is deprecated", DeprecationWarning)
144154

145155
if channel:
146156
# Sanity check: Ensure that channel and credentials are not both
@@ -151,8 +161,6 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
151161
self._grpc_channel = channel
152162
self._ssl_channel_credentials = None
153163
elif api_mtls_endpoint:
154-
warnings.warn("api_mtls_endpoint and client_cert_source are deprecated", DeprecationWarning)
155-
156164
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"
157165

158166
if credentials is None:
@@ -188,12 +196,18 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
188196
if credentials is None:
189197
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
190198

199+
if client_cert_source_for_mtls and not ssl_channel_credentials:
200+
cert, key = client_cert_source_for_mtls()
201+
self._ssl_channel_credentials = grpc.ssl_channel_credentials(
202+
certificate_chain=cert, private_key=key
203+
)
204+
191205
# create a new channel. The provided one is ignored.
192206
self._grpc_channel = type(self).create_channel(
193207
host,
194208
credentials=credentials,
195209
credentials_file=credentials_file,
196-
ssl_credentials=ssl_channel_credentials,
210+
ssl_credentials=self._ssl_channel_credentials,
197211
scopes=scopes or self.AUTH_SCOPES,
198212
quota_project_id=quota_project_id,
199213
options=[

gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
4848
credentials: credentials.Credentials = None,
4949
credentials_file: str = None,
5050
scopes: Sequence[str] = None,
51-
ssl_channel_credentials: grpc.ChannelCredentials = None,
51+
client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
5252
quota_project_id: Optional[str] = None,
5353
client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
5454
) -> None:
@@ -68,8 +68,9 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
6868
This argument is ignored if ``channel`` is provided.
6969
scopes (Optional(Sequence[str])): A list of scopes. This argument is
7070
ignored if ``channel`` is provided.
71-
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
72-
for grpc channel. It is ignored if ``channel`` is provided.
71+
client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client
72+
certificate to configure mutual TLS HTTP channel. It is ignored
73+
if ``channel`` is provided.
7374
quota_project_id (Optional[str]): An optional project to use for billing
7475
and quota.
7576
client_info (google.api_core.gapic_v1.client_info.ClientInfo):
@@ -89,6 +90,8 @@ class {{ service.name }}RestTransport({{ service.name }}Transport):
8990
{%- if service.has_lro %}
9091
self._operations_client = None
9192
{%- endif %}
93+
if client_cert_source_for_mtls:
94+
self._session.configure_mtls_channel(client_cert_source_for_mtls)
9295

9396
{%- if service.has_lro %}
9497

0 commit comments

Comments
 (0)