diff --git a/spanner/google/cloud/spanner/session.py b/spanner/google/cloud/spanner/session.py index ecf0995938ef..9e0a6d740dac 100644 --- a/spanner/google/cloud/spanner/session.py +++ b/spanner/google/cloud/spanner/session.py @@ -14,6 +14,7 @@ """Wrapper for Cloud Spanner Session objects.""" +from functools import total_ordering import time from google.gax.errors import GaxError @@ -34,6 +35,7 @@ """Default timeout used by :meth:`Session.run_in_transaction`.""" +@total_ordering class Session(object): """Representation of a Cloud Spanner Session. @@ -53,6 +55,9 @@ class Session(object): def __init__(self, database): self._database = database + def __lt__(self, other): + return self._session_id < other._session_id + @property def session_id(self): """Read-only ID, set by the back-end during :meth:`create`.""" diff --git a/spanner/unit_tests/test_pool.py b/spanner/unit_tests/test_pool.py index f017fbc84e6b..e4124dcf6b99 100644 --- a/spanner/unit_tests/test_pool.py +++ b/spanner/unit_tests/test_pool.py @@ -13,6 +13,7 @@ # limitations under the License. +from functools import total_ordering import unittest @@ -597,6 +598,32 @@ def test_bind(self): self.assertTrue(pool._pending_sessions.empty()) + def test_bind_w_timestamp_race(self): + import datetime + from google.cloud._testing import _Monkey + from google.cloud.spanner import pool as MUT + NOW = datetime.datetime.utcnow() + pool = self._make_one() + database = _Database('name') + SESSIONS = [_Session(database) for _ in range(10)] + database._sessions.extend(SESSIONS) + + with _Monkey(MUT, _NOW=lambda: NOW): + pool.bind(database) + + self.assertIs(pool._database, database) + self.assertEqual(pool.size, 10) + self.assertEqual(pool.default_timeout, 10) + self.assertEqual(pool._delta.seconds, 3000) + self.assertTrue(pool._sessions.full()) + + for session in SESSIONS: + self.assertTrue(session._created) + txn = session._transaction + self.assertTrue(txn._begun) + + self.assertTrue(pool._pending_sessions.empty()) + def test_put_full(self): from six.moves.queue import Full @@ -755,6 +782,7 @@ def committed(self): return self._committed +@total_ordering class _Session(object): _transaction = None @@ -767,6 +795,9 @@ def __init__(self, database, exists=True, transaction=None): self._deleted = False self._transaction = transaction + def __lt__(self, other): + return id(self) < id(other) + def create(self): self._created = True diff --git a/spanner/unit_tests/test_session.py b/spanner/unit_tests/test_session.py index c7257adca15f..37fad4570e26 100644 --- a/spanner/unit_tests/test_session.py +++ b/spanner/unit_tests/test_session.py @@ -42,6 +42,14 @@ def test_constructor(self): self.assertTrue(session.session_id is None) self.assertTrue(session._database is database) + def test___lt___(self): + database = _Database(self.DATABASE_NAME) + lhs = self._make_one(database) + lhs._session_id = b'123' + rhs = self._make_one(database) + rhs._session_id = b'234' + self.assertTrue(lhs < rhs) + def test_name_property_wo_session_id(self): database = _Database(self.DATABASE_NAME) session = self._make_one(database)