Skip to content

Commit f5b7c79

Browse files
tseaverlandrito
authored andcommitted
Reuse explicit credentials when creating 'database.spanner_api'. (googleapis#3722)
- Preserves "custom" credentials (existing code worked only with implicit credentials from the environment). - Add tests ensuring scopes are set for correctly for all GAX apis (client uses admin scope, which do not grant data access, while database uses data scope, which does not grant admin access).
1 parent 88ccc6f commit f5b7c79

File tree

3 files changed

+117
-42
lines changed

3 files changed

+117
-42
lines changed

spanner/google/cloud/spanner/database.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import re
1818

19+
import google.auth.credentials
1920
from google.gax.errors import GaxError
2021
from google.gax.grpc import exc_to_code
2122
from google.cloud.gapic.spanner.v1.spanner_client import SpannerClient
@@ -35,6 +36,9 @@
3536
# pylint: enable=ungrouped-imports
3637

3738

39+
SPANNER_DATA_SCOPE = 'https://www.googleapis.com/auth/spanner.data'
40+
41+
3842
_DATABASE_NAME_RE = re.compile(
3943
r'^projects/(?P<project>[^/]+)/'
4044
r'instances/(?P<instance_id>[a-z][-a-z0-9]*)/'
@@ -154,8 +158,14 @@ def ddl_statements(self):
154158
def spanner_api(self):
155159
"""Helper for session-related API calls."""
156160
if self._spanner_api is None:
161+
credentials = self._instance._client.credentials
162+
if isinstance(credentials, google.auth.credentials.Scoped):
163+
credentials = credentials.with_scopes((SPANNER_DATA_SCOPE,))
157164
self._spanner_api = SpannerClient(
158-
lib_name='gccl', lib_version=__version__)
165+
lib_name='gccl',
166+
lib_version=__version__,
167+
credentials=credentials,
168+
)
159169
return self._spanner_api
160170

161171
def __eq__(self, other):

spanner/tests/unit/test_client.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -145,46 +145,56 @@ def test_admin_api_lib_name(self):
145145
__version__)
146146

147147
def test_instance_admin_api(self):
148-
from google.cloud._testing import _Monkey
149-
from google.cloud.spanner import client as MUT
148+
from google.cloud.spanner import __version__
149+
from google.cloud.spanner.client import SPANNER_ADMIN_SCOPE
150150

151-
creds = _make_credentials()
152-
client = self._make_one(project=self.PROJECT, credentials=creds)
151+
credentials = _make_credentials()
152+
client = self._make_one(project=self.PROJECT, credentials=credentials)
153+
expected_scopes = (SPANNER_ADMIN_SCOPE,)
153154

154-
class _Client(object):
155-
def __init__(self, *args, **kwargs):
156-
self.args = args
157-
self.kwargs = kwargs
155+
patch = mock.patch('google.cloud.spanner.client.InstanceAdminClient')
158156

159-
with _Monkey(MUT, InstanceAdminClient=_Client):
157+
with patch as instance_admin_client:
160158
api = client.instance_admin_api
161159

162-
self.assertTrue(isinstance(api, _Client))
160+
self.assertIs(api, instance_admin_client.return_value)
161+
162+
# API instance is cached
163163
again = client.instance_admin_api
164164
self.assertIs(again, api)
165-
self.assertEqual(api.kwargs['lib_name'], 'gccl')
166-
self.assertIs(api.kwargs['credentials'], client.credentials)
165+
166+
instance_admin_client.assert_called_once_with(
167+
lib_name='gccl',
168+
lib_version=__version__,
169+
credentials=credentials.with_scopes.return_value)
170+
171+
credentials.with_scopes.assert_called_once_with(expected_scopes)
167172

168173
def test_database_admin_api(self):
169-
from google.cloud._testing import _Monkey
170-
from google.cloud.spanner import client as MUT
174+
from google.cloud.spanner import __version__
175+
from google.cloud.spanner.client import SPANNER_ADMIN_SCOPE
171176

172-
creds = _make_credentials()
173-
client = self._make_one(project=self.PROJECT, credentials=creds)
177+
credentials = _make_credentials()
178+
client = self._make_one(project=self.PROJECT, credentials=credentials)
179+
expected_scopes = (SPANNER_ADMIN_SCOPE,)
174180

175-
class _Client(object):
176-
def __init__(self, *args, **kwargs):
177-
self.args = args
178-
self.kwargs = kwargs
181+
patch = mock.patch('google.cloud.spanner.client.DatabaseAdminClient')
179182

180-
with _Monkey(MUT, DatabaseAdminClient=_Client):
183+
with patch as database_admin_client:
181184
api = client.database_admin_api
182185

183-
self.assertTrue(isinstance(api, _Client))
186+
self.assertIs(api, database_admin_client.return_value)
187+
188+
# API instance is cached
184189
again = client.database_admin_api
185190
self.assertIs(again, api)
186-
self.assertEqual(api.kwargs['lib_name'], 'gccl')
187-
self.assertIs(api.kwargs['credentials'], client.credentials)
191+
192+
database_admin_client.assert_called_once_with(
193+
lib_name='gccl',
194+
lib_version=__version__,
195+
credentials=credentials.with_scopes.return_value)
196+
197+
credentials.with_scopes.assert_called_once_with(expected_scopes)
188198

189199
def test_copy(self):
190200
credentials = _make_credentials()

spanner/tests/unit/test_database.py

Lines changed: 72 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,23 @@
1515

1616
import unittest
1717

18-
from google.cloud.spanner import __version__
18+
import mock
1919

2020
from google.cloud._testing import _GAXBaseAPI
2121

22+
from google.cloud.spanner import __version__
23+
24+
25+
def _make_credentials():
26+
import google.auth.credentials
27+
28+
class _CredentialsWithScopes(
29+
google.auth.credentials.Credentials,
30+
google.auth.credentials.Scoped):
31+
pass
32+
33+
return mock.Mock(spec=_CredentialsWithScopes)
34+
2235

2336
class _BaseTest(unittest.TestCase):
2437

@@ -176,30 +189,72 @@ def test_name_property(self):
176189
expected_name = self.DATABASE_NAME
177190
self.assertEqual(database.name, expected_name)
178191

179-
def test_spanner_api_property(self):
180-
from google.cloud._testing import _Monkey
181-
from google.cloud.spanner import database as MUT
182-
192+
def test_spanner_api_property_w_scopeless_creds(self):
183193
client = _Client()
194+
credentials = client.credentials = object()
184195
instance = _Instance(self.INSTANCE_NAME, client=client)
185196
pool = _Pool()
186197
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
187198

188-
_client = object()
189-
_clients = [_client]
199+
patch = mock.patch('google.cloud.spanner.database.SpannerClient')
200+
201+
with patch as spanner_client:
202+
api = database.spanner_api
203+
204+
self.assertIs(api, spanner_client.return_value)
205+
206+
# API instance is cached
207+
again = database.spanner_api
208+
self.assertIs(again, api)
209+
210+
spanner_client.assert_called_once_with(
211+
lib_name='gccl',
212+
lib_version=__version__,
213+
credentials=credentials)
190214

191-
def _mock_spanner_client(*args, **kwargs):
192-
self.assertIsInstance(args, tuple)
193-
self.assertEqual(kwargs['lib_name'], 'gccl')
194-
self.assertEqual(kwargs['lib_version'], __version__)
195-
return _clients.pop(0)
215+
def test_spanner_api_w_scoped_creds(self):
216+
import google.auth.credentials
217+
from google.cloud.spanner.database import SPANNER_DATA_SCOPE
196218

197-
with _Monkey(MUT, SpannerClient=_mock_spanner_client):
219+
class _CredentialsWithScopes(
220+
google.auth.credentials.Scoped):
221+
222+
def __init__(self, scopes=(), source=None):
223+
self._scopes = scopes
224+
self._source = source
225+
226+
def requires_scopes(self):
227+
return True
228+
229+
def with_scopes(self, scopes):
230+
return self.__class__(scopes, self)
231+
232+
expected_scopes = (SPANNER_DATA_SCOPE,)
233+
client = _Client()
234+
credentials = client.credentials = _CredentialsWithScopes()
235+
instance = _Instance(self.INSTANCE_NAME, client=client)
236+
pool = _Pool()
237+
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
238+
239+
patch = mock.patch('google.cloud.spanner.database.SpannerClient')
240+
241+
with patch as spanner_client:
198242
api = database.spanner_api
199-
self.assertIs(api, _client)
200-
# API instance is cached
201-
again = database.spanner_api
202-
self.assertIs(again, api)
243+
244+
self.assertIs(api, spanner_client.return_value)
245+
246+
# API instance is cached
247+
again = database.spanner_api
248+
self.assertIs(again, api)
249+
250+
self.assertEqual(len(spanner_client.call_args_list), 1)
251+
called_args, called_kw = spanner_client.call_args
252+
self.assertEqual(called_args, ())
253+
self.assertEqual(called_kw['lib_name'], 'gccl')
254+
self.assertEqual(called_kw['lib_version'], __version__)
255+
scoped = called_kw['credentials']
256+
self.assertEqual(scoped._scopes, expected_scopes)
257+
self.assertIs(scoped._source, credentials)
203258

204259
def test___eq__(self):
205260
instance = _Instance(self.INSTANCE_NAME)

0 commit comments

Comments
 (0)