diff --git a/spanner/google/cloud/spanner/database.py b/spanner/google/cloud/spanner/database.py index 9b838bfaa878..acfcefdce891 100644 --- a/spanner/google/cloud/spanner/database.py +++ b/spanner/google/cloud/spanner/database.py @@ -16,6 +16,7 @@ import re +import google.auth.credentials from google.gax.errors import GaxError from google.gax.grpc import exc_to_code from google.cloud.gapic.spanner.v1.spanner_client import SpannerClient @@ -35,6 +36,9 @@ # pylint: enable=ungrouped-imports +SPANNER_DATA_SCOPE = 'https://www.googleapis.com/auth/spanner.data' + + _DATABASE_NAME_RE = re.compile( r'^projects/(?P[^/]+)/' r'instances/(?P[a-z][-a-z0-9]*)/' @@ -154,8 +158,14 @@ def ddl_statements(self): def spanner_api(self): """Helper for session-related API calls.""" if self._spanner_api is None: + credentials = self._instance._client.credentials + if isinstance(credentials, google.auth.credentials.Scoped): + credentials = credentials.with_scopes((SPANNER_DATA_SCOPE,)) self._spanner_api = SpannerClient( - lib_name='gccl', lib_version=__version__) + lib_name='gccl', + lib_version=__version__, + credentials=credentials, + ) return self._spanner_api def __eq__(self, other): diff --git a/spanner/tests/unit/test_client.py b/spanner/tests/unit/test_client.py index 28eee9b78f56..5fd79ab86ebb 100644 --- a/spanner/tests/unit/test_client.py +++ b/spanner/tests/unit/test_client.py @@ -145,46 +145,56 @@ def test_admin_api_lib_name(self): __version__) def test_instance_admin_api(self): - from google.cloud._testing import _Monkey - from google.cloud.spanner import client as MUT + from google.cloud.spanner import __version__ + from google.cloud.spanner.client import SPANNER_ADMIN_SCOPE - creds = _make_credentials() - client = self._make_one(project=self.PROJECT, credentials=creds) + credentials = _make_credentials() + client = self._make_one(project=self.PROJECT, credentials=credentials) + expected_scopes = (SPANNER_ADMIN_SCOPE,) - class _Client(object): - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs + patch = mock.patch('google.cloud.spanner.client.InstanceAdminClient') - with _Monkey(MUT, InstanceAdminClient=_Client): + with patch as instance_admin_client: api = client.instance_admin_api - self.assertTrue(isinstance(api, _Client)) + self.assertIs(api, instance_admin_client.return_value) + + # API instance is cached again = client.instance_admin_api self.assertIs(again, api) - self.assertEqual(api.kwargs['lib_name'], 'gccl') - self.assertIs(api.kwargs['credentials'], client.credentials) + + instance_admin_client.assert_called_once_with( + lib_name='gccl', + lib_version=__version__, + credentials=credentials.with_scopes.return_value) + + credentials.with_scopes.assert_called_once_with(expected_scopes) def test_database_admin_api(self): - from google.cloud._testing import _Monkey - from google.cloud.spanner import client as MUT + from google.cloud.spanner import __version__ + from google.cloud.spanner.client import SPANNER_ADMIN_SCOPE - creds = _make_credentials() - client = self._make_one(project=self.PROJECT, credentials=creds) + credentials = _make_credentials() + client = self._make_one(project=self.PROJECT, credentials=credentials) + expected_scopes = (SPANNER_ADMIN_SCOPE,) - class _Client(object): - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs + patch = mock.patch('google.cloud.spanner.client.DatabaseAdminClient') - with _Monkey(MUT, DatabaseAdminClient=_Client): + with patch as database_admin_client: api = client.database_admin_api - self.assertTrue(isinstance(api, _Client)) + self.assertIs(api, database_admin_client.return_value) + + # API instance is cached again = client.database_admin_api self.assertIs(again, api) - self.assertEqual(api.kwargs['lib_name'], 'gccl') - self.assertIs(api.kwargs['credentials'], client.credentials) + + database_admin_client.assert_called_once_with( + lib_name='gccl', + lib_version=__version__, + credentials=credentials.with_scopes.return_value) + + credentials.with_scopes.assert_called_once_with(expected_scopes) def test_copy(self): credentials = _make_credentials() diff --git a/spanner/tests/unit/test_database.py b/spanner/tests/unit/test_database.py index aa1643ed7582..ec94e0198c77 100644 --- a/spanner/tests/unit/test_database.py +++ b/spanner/tests/unit/test_database.py @@ -15,10 +15,23 @@ import unittest -from google.cloud.spanner import __version__ +import mock from google.cloud._testing import _GAXBaseAPI +from google.cloud.spanner import __version__ + + +def _make_credentials(): + import google.auth.credentials + + class _CredentialsWithScopes( + google.auth.credentials.Credentials, + google.auth.credentials.Scoped): + pass + + return mock.Mock(spec=_CredentialsWithScopes) + class _BaseTest(unittest.TestCase): @@ -176,30 +189,72 @@ def test_name_property(self): expected_name = self.DATABASE_NAME self.assertEqual(database.name, expected_name) - def test_spanner_api_property(self): - from google.cloud._testing import _Monkey - from google.cloud.spanner import database as MUT - + def test_spanner_api_property_w_scopeless_creds(self): client = _Client() + credentials = client.credentials = object() instance = _Instance(self.INSTANCE_NAME, client=client) pool = _Pool() database = self._make_one(self.DATABASE_ID, instance, pool=pool) - _client = object() - _clients = [_client] + patch = mock.patch('google.cloud.spanner.database.SpannerClient') + + with patch as spanner_client: + api = database.spanner_api + + self.assertIs(api, spanner_client.return_value) + + # API instance is cached + again = database.spanner_api + self.assertIs(again, api) + + spanner_client.assert_called_once_with( + lib_name='gccl', + lib_version=__version__, + credentials=credentials) - def _mock_spanner_client(*args, **kwargs): - self.assertIsInstance(args, tuple) - self.assertEqual(kwargs['lib_name'], 'gccl') - self.assertEqual(kwargs['lib_version'], __version__) - return _clients.pop(0) + def test_spanner_api_w_scoped_creds(self): + import google.auth.credentials + from google.cloud.spanner.database import SPANNER_DATA_SCOPE - with _Monkey(MUT, SpannerClient=_mock_spanner_client): + class _CredentialsWithScopes( + google.auth.credentials.Scoped): + + def __init__(self, scopes=(), source=None): + self._scopes = scopes + self._source = source + + def requires_scopes(self): + return True + + def with_scopes(self, scopes): + return self.__class__(scopes, self) + + expected_scopes = (SPANNER_DATA_SCOPE,) + client = _Client() + credentials = client.credentials = _CredentialsWithScopes() + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + patch = mock.patch('google.cloud.spanner.database.SpannerClient') + + with patch as spanner_client: api = database.spanner_api - self.assertIs(api, _client) - # API instance is cached - again = database.spanner_api - self.assertIs(again, api) + + self.assertIs(api, spanner_client.return_value) + + # API instance is cached + again = database.spanner_api + self.assertIs(again, api) + + self.assertEqual(len(spanner_client.call_args_list), 1) + called_args, called_kw = spanner_client.call_args + self.assertEqual(called_args, ()) + self.assertEqual(called_kw['lib_name'], 'gccl') + self.assertEqual(called_kw['lib_version'], __version__) + scoped = called_kw['credentials'] + self.assertEqual(scoped._scopes, expected_scopes) + self.assertIs(scoped._source, credentials) def test___eq__(self): instance = _Instance(self.INSTANCE_NAME)