diff --git a/gcloud/datastore/client.py b/gcloud/datastore/client.py index 3c2cecf8cfcd..26fbac0d9240 100644 --- a/gcloud/datastore/client.py +++ b/gcloud/datastore/client.py @@ -13,6 +13,7 @@ # limitations under the License. """Convenience wrapper for invoking APIs/factories w/ a dataset ID.""" +from gcloud._helpers import _LocalStack from gcloud.datastore import helpers from gcloud.datastore.batch import Batch from gcloud.datastore.entity import Entity @@ -127,8 +128,42 @@ def __init__(self, dataset_id=None, namespace=None, connection=None): if connection is None: connection = get_connection() self.connection = connection + self._batch_stack = _LocalStack() self.namespace = namespace + def _push_batch(self, batch): + """Push a batch/transaction onto our stack. + + "Protected", intended for use by batch / transaction context mgrs. + + :type batch: :class:`gcloud.datastore.batch.Batch`, or an object + implementing its API. + :param batch: newly-active batch/batch/transaction. + """ + self._batch_stack.push(batch) + + def _pop_batch(self): + """Pop a batch/transaction from our stack. + + "Protected", intended for use by batch / transaction context mgrs. + + :raises: IndexError if the stack is empty. + :rtype: :class:`gcloud.datastore.batch.Batch`, or an object + implementing its API. + :returns: the top-most batch/transaction, after removing it. + """ + return self._batch_stack.pop() + + @property + def current_batch(self): + """Currently-active batch. + + :rtype: :class:`gcloud.datastore.batch.Batch`, or an object + implementing its API, or ``NoneType`` (if no batch is active). + :returns: The batch/transaction at the toop of the batch stack. + """ + return self._batch_stack.top + def get(self, key, missing=None, deferred=None): """Retrieve an entity from a single key (if it exists). diff --git a/gcloud/datastore/test_client.py b/gcloud/datastore/test_client.py index 96334c805fef..9da4f0262946 100644 --- a/gcloud/datastore/test_client.py +++ b/gcloud/datastore/test_client.py @@ -63,6 +63,7 @@ def test_ctor_w_implicit_inputs(self): self.assertEqual(client.dataset_id, OTHER) self.assertEqual(client.namespace, None) self.assertTrue(client.connection is conn) + self.assertTrue(client.current_batch is None) def test_ctor_w_explicit_inputs(self): OTHER = 'other' @@ -74,6 +75,25 @@ def test_ctor_w_explicit_inputs(self): self.assertEqual(client.dataset_id, OTHER) self.assertEqual(client.namespace, NAMESPACE) self.assertTrue(client.connection is conn) + self.assertTrue(client.current_batch is None) + self.assertEqual(list(client._batch_stack), []) + + def test__push_connection_and__pop_connection(self): + conn = object() + batch1 = object() + batch2 = object() + client = self._makeOne(connection=conn) + client._push_batch(batch1) + self.assertEqual(list(client._batch_stack), [batch1]) + self.assertTrue(client.current_batch is batch1) + client._push_batch(batch2) + self.assertTrue(client.current_batch is batch2) + # list(_LocalStack) returns in reverse order. + self.assertEqual(list(client._batch_stack), [batch2, batch1]) + self.assertTrue(client._pop_batch() is batch2) + self.assertEqual(list(client._batch_stack), [batch1]) + self.assertTrue(client._pop_batch() is batch1) + self.assertEqual(list(client._batch_stack), []) def test_get_miss(self): _called_with = []