diff --git a/gcloud/datastore/batch.py b/gcloud/datastore/batch.py
index 8bc0084192bb..a137e0b64c14 100644
--- a/gcloud/datastore/batch.py
+++ b/gcloud/datastore/batch.py
@@ -21,16 +21,11 @@
 https://cloud.google.com/datastore/docs/concepts/entities#Datastore_Batch_operations
 """
 
-from gcloud._helpers import _LocalStack
-from gcloud.datastore import _implicit_environ
 from gcloud.datastore import helpers
 from gcloud.datastore.key import _dataset_ids_equal
 from gcloud.datastore import _datastore_v1_pb2 as datastore_pb
 
 
-_BATCHES = _LocalStack()
-
-
 class Batch(object):
     """An abstraction representing a collected group of updates / deletes.
 
@@ -62,34 +57,19 @@ class Batch(object):
     ...     do_some_work(batch)
     ...     raise Exception()  # rolls back
 
-    :type dataset_id: :class:`str`.
-    :param dataset_id: The ID of the dataset.
-
-    :type connection: :class:`gcloud.datastore.connection.Connection`
-    :param connection: The connection used to connect to datastore.
-
-    :raises: :class:`ValueError` if either a connection or dataset ID
-             are not set.
+    :type client: :class:`gcloud.datastore.client.Client`
+    :param client: The client used to connect to datastore.
     """
 
     _id = None  # "protected" attribute, always None for non-transactions
 
-    def __init__(self, dataset_id=None, connection=None):
-        self._connection = (connection or
-                            _implicit_environ.get_default_connection())
-        self._dataset_id = (dataset_id or
-                            _implicit_environ.get_default_dataset_id())
-
-        if self._connection is None or self._dataset_id is None:
-            raise ValueError('A batch must have a connection and '
-                             'a dataset ID set.')
-
+    def __init__(self, client):
+        self._client = client
         self._mutation = datastore_pb.Mutation()
         self._auto_id_entities = []
 
-    @staticmethod
-    def current():
+    def current(self):
         """Return the topmost batch / transaction, or None."""
-        return _BATCHES.top
+        return self._client.current_batch
 
     @property
     def dataset_id(self):
@@ -98,7 +78,16 @@ def dataset_id(self):
         :rtype: :class:`str`
         :returns: The dataset ID in which the batch will run.
         """
-        return self._dataset_id
+        return self._client.dataset_id
+
+    @property
+    def namespace(self):
+        """Getter for namespace in which the batch will run.
+
+        :rtype: :class:`str`
+        :returns: The namespace in which the batch will run.
+        """
+        return self._client.namespace
 
     @property
     def connection(self):
@@ -107,7 +96,7 @@ def connection(self):
         :rtype: :class:`gcloud.datastore.connection.Connection`
         :returns: The connection over which the batch will run.
         """
-        return self._connection
+        return self._client.connection
 
     @property
     def mutation(self):
@@ -172,7 +161,7 @@ def put(self, entity):
         if entity.key is None:
             raise ValueError("Entity must have a key")
 
-        if not _dataset_ids_equal(self._dataset_id, entity.key.dataset_id):
+        if not _dataset_ids_equal(self.dataset_id, entity.key.dataset_id):
             raise ValueError("Key must be from same dataset as batch")
 
         _assign_entity_to_mutation(
@@ -190,7 +179,7 @@ def delete(self, key):
         if key.is_partial:
             raise ValueError("Key must be complete")
 
-        if not _dataset_ids_equal(self._dataset_id, key.dataset_id):
+        if not _dataset_ids_equal(self.dataset_id, key.dataset_id):
             raise ValueError("Key must be from same dataset as batch")
 
         key_pb = helpers._prepare_key_for_request(key.to_protobuf())
@@ -211,7 +200,7 @@ def commit(self):
         context manager.
""" response = self.connection.commit( - self._dataset_id, self.mutation, self._id) + self.dataset_id, self.mutation, self._id) # If the back-end returns without error, we are guaranteed that # the response's 'insert_auto_id_key' will match (length and order) # the request's 'insert_auto_id` entities, which are derived from @@ -229,7 +218,7 @@ def rollback(self): pass def __enter__(self): - _BATCHES.push(self) + self._client._push_batch(self) self.begin() return self @@ -240,7 +229,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): else: self.rollback() finally: - _BATCHES.pop() + self._client._pop_batch() def _assign_entity_to_mutation(mutation_pb, entity, auto_id_entities): diff --git a/gcloud/datastore/client.py b/gcloud/datastore/client.py index 26fbac0d9240..0dc60850f1eb 100644 --- a/gcloud/datastore/client.py +++ b/gcloud/datastore/client.py @@ -164,6 +164,19 @@ def current_batch(self): """ return self._batch_stack.top + @property + def current_transaction(self): + """Currently-active transaction. + + :rtype: :class:`gcloud.datastore.transaction.Transaction`, or an object + implementing its API, or ``NoneType`` (if no transaction is + active). + :returns: The transaction at the toop of the batch stack. + """ + transaction = self.current_batch + if isinstance(transaction, Transaction): + return transaction + def get(self, key, missing=None, deferred=None): """Retrieve an entity from a single key (if it exists). @@ -222,7 +235,7 @@ def get_multi(self, keys, missing=None, deferred=None): if ids != [self.dataset_id]: raise ValueError('Keys do not match dataset ID') - transaction = Transaction.current() + transaction = self.current_transaction entity_pbs = _extended_lookup( connection=self.connection, @@ -274,12 +287,12 @@ def put_multi(self, entities): if not entities: return - current = Batch.current() + current = self.current_batch in_batch = current is not None if not in_batch: - current = Batch(dataset_id=self.dataset_id, - connection=self.connection) + current = self.batch() + for entity in entities: current.put(entity) @@ -310,12 +323,12 @@ def delete_multi(self, keys): return # We allow partial keys to attempt a delete, the backend will fail. - current = Batch.current() + current = self.current_batch in_batch = current is not None if not in_batch: - current = Batch(dataset_id=self.dataset_id, - connection=self.connection) + current = self.batch() + for key in keys: current.delete(key) @@ -368,15 +381,14 @@ def batch(self): Passes our ``dataset_id``. """ - return Batch(dataset_id=self.dataset_id, connection=self.connection) + return Batch(self) def transaction(self): """Proxy to :class:`gcloud.datastore.transaction.Transaction`. Passes our ``dataset_id``. """ - return Transaction(dataset_id=self.dataset_id, - connection=self.connection) + return Transaction(self) def query(self, **kwargs): """Proxy to :class:`gcloud.datastore.query.Query`. diff --git a/gcloud/datastore/query.py b/gcloud/datastore/query.py index b28bbf0ba0ce..4adfece8520d 100644 --- a/gcloud/datastore/query.py +++ b/gcloud/datastore/query.py @@ -17,11 +17,9 @@ import base64 from gcloud._helpers import _ensure_tuple_or_list -from gcloud.datastore import _implicit_environ from gcloud.datastore import _datastore_v1_pb2 as datastore_pb from gcloud.datastore import helpers from gcloud.datastore.key import Key -from gcloud.datastore.transaction import Transaction class Query(object): @@ -30,15 +28,19 @@ class Query(object): This class serves as an abstraction for creating a query over data stored in the Cloud Datastore. 
 
+    :type client: :class:`gcloud.datastore.client.Client`
+    :param client: The client used to connect to datastore.
+
     :type kind: string
     :param kind: The kind to query.
 
     :type dataset_id: string
     :param dataset_id: The ID of the dataset to query. If not passed,
-                       uses the implicit default.
+                       uses the client's value.
 
     :type namespace: string or None
-    :param namespace: The namespace to which to restrict results.
+    :param namespace: The namespace to which to restrict results. If not
+                      passed, uses the client's value.
 
     :type ancestor: :class:`gcloud.datastore.key.Key` or None
     :param ancestor: key of the ancestor to which this query's results are
@@ -71,6 +73,7 @@ class Query(object):
     """Mapping of operator strings and their protobuf equivalents."""
 
     def __init__(self,
+                 client,
                  kind=None,
                  dataset_id=None,
                  namespace=None,
@@ -80,15 +83,10 @@ def __init__(self,
                  order=(),
                  group_by=()):
 
-        if dataset_id is None:
-            dataset_id = _implicit_environ.get_default_dataset_id()
-
-        if dataset_id is None:
-            raise ValueError("No dataset ID supplied, and no default set.")
-
-        self._dataset_id = dataset_id
+        self._client = client
         self._kind = kind
-        self._namespace = namespace
+        self._dataset_id = dataset_id or client.dataset_id
+        self._namespace = namespace or client.namespace
         self._ancestor = ancestor
         self._filters = []
         # Verify filters passed in.
@@ -294,7 +292,7 @@ def group_by(self, value):
         self._group_by[:] = value
 
     def fetch(self, limit=None, offset=0, start_cursor=None, end_cursor=None,
-              connection=None):
+              client=None):
         """Execute the Query; return an iterator for the matching entities.
 
         For example::
@@ -319,22 +317,19 @@ def fetch(self, limit=None, offset=0, start_cursor=None, end_cursor=None,
         :type end_cursor: bytes
         :param end_cursor: An optional cursor passed through to the iterator.
 
-        :type connection: :class:`gcloud.datastore.connection.Connection`
-        :param connection: An optional cursor passed through to the iterator.
-            If not supplied, uses the implicit default.
+        :type client: :class:`gcloud.datastore.client.Client`
+        :param client: The client used to connect to datastore. If not
+            supplied, uses the query's value.
 
         :rtype: :class:`Iterator`
         :raises: ValueError if ``connection`` is not passed and no implicit
                  default has been set.
         """
-        if connection is None:
-            connection = _implicit_environ.get_default_connection()
-
-        if connection is None:
-            raise ValueError("No connection passed, and no default set")
+        if client is None:
+            client = self._client
 
         return Iterator(
-            self, connection, limit, offset, start_cursor, end_cursor)
+            self, client, limit, offset, start_cursor, end_cursor)
 
 
 class Iterator(object):
@@ -347,10 +342,10 @@ class Iterator(object):
         datastore_pb.QueryResultBatch.MORE_RESULTS_AFTER_LIMIT,
     )
 
-    def __init__(self, query, connection, limit=None, offset=0,
+    def __init__(self, query, client, limit=None, offset=0,
                  start_cursor=None, end_cursor=None):
         self._query = query
-        self._connection = connection
+        self._client = client
         self._limit = limit
         self._offset = offset
         self._start_cursor = start_cursor
@@ -380,9 +375,9 @@ def next_page(self):
 
         pb.offset = self._offset
 
-        transaction = Transaction.current()
+        transaction = self._client.current_transaction
 
-        query_results = self._connection.run_query(
+        query_results = self._client.connection.run_query(
             query_pb=pb,
             dataset_id=self._query.dataset_id,
             namespace=self._query.namespace,
diff --git a/gcloud/datastore/test_batch.py b/gcloud/datastore/test_batch.py
index c00c58156f23..a030560d834f 100644
--- a/gcloud/datastore/test_batch.py
+++ b/gcloud/datastore/test_batch.py
@@ -22,41 +22,20 @@ def _getTargetClass(self):
         return Batch
 
-    def _makeOne(self, dataset_id=None, connection=None):
-        return self._getTargetClass()(dataset_id=dataset_id,
-                                      connection=connection)
+    def _makeOne(self, client):
+        return self._getTargetClass()(client)
 
-    def test_ctor_missing_required(self):
-        from gcloud.datastore._testing import _monkey_defaults
-
-        with _monkey_defaults():
-            self.assertRaises(ValueError, self._makeOne)
-            self.assertRaises(ValueError, self._makeOne, dataset_id=object())
-            self.assertRaises(ValueError, self._makeOne, connection=object())
-
-    def test_ctor_explicit(self):
+    def test_ctor(self):
         from gcloud.datastore._datastore_v1_pb2 import Mutation
         _DATASET = 'DATASET'
+        _NAMESPACE = 'NAMESPACE'
         connection = _Connection()
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection, _NAMESPACE)
+        batch = self._makeOne(client)
         self.assertEqual(batch.dataset_id, _DATASET)
         self.assertEqual(batch.connection, connection)
-        self.assertTrue(batch._id is None)
-        self.assertTrue(isinstance(batch.mutation, Mutation))
-        self.assertEqual(batch._auto_id_entities, [])
-
-    def test_ctor_implicit(self):
-        from gcloud.datastore._testing import _monkey_defaults
-        from gcloud.datastore._datastore_v1_pb2 import Mutation
-        _DATASET = 'DATASET'
-        CONNECTION = _Connection()
-
-        with _monkey_defaults(connection=CONNECTION, dataset_id=_DATASET):
-            batch = self._makeOne()
-
-        self.assertEqual(batch.dataset_id, _DATASET)
-        self.assertEqual(batch.connection, CONNECTION)
+        self.assertEqual(batch.namespace, _NAMESPACE)
         self.assertTrue(batch._id is None)
         self.assertTrue(isinstance(batch.mutation, Mutation))
         self.assertEqual(batch._auto_id_entities, [])
@@ -64,8 +43,9 @@ def test_ctor_implicit(self):
     def test_current(self):
         _DATASET = 'DATASET'
         connection = _Connection()
-        batch1 = self._makeOne(_DATASET, connection)
-        batch2 = self._makeOne(_DATASET, connection)
+        client = _Client(_DATASET, connection)
+        batch1 = self._makeOne(client)
+        batch2 = self._makeOne(client)
         self.assertTrue(batch1.current() is None)
         self.assertTrue(batch2.current() is None)
         with batch1:
@@ -82,7 +62,8 @@ def test_current(self):
     def test_add_auto_id_entity_w_partial_key(self):
         _DATASET = 'DATASET'
         connection = _Connection()
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        batch = self._makeOne(client)
         entity = _Entity()
         key = entity.key = _Key(_DATASET)
         key._id = None
@@ -94,7 +75,8 @@ def test_add_auto_id_entity_w_partial_key(self):
     def test_add_auto_id_entity_w_completed_key(self):
         _DATASET = 'DATASET'
         connection = _Connection()
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        batch = self._makeOne(client)
         entity = _Entity()
         entity.key = _Key(_DATASET)
 
@@ -103,14 +85,16 @@ def test_add_auto_id_entity_w_completed_key(self):
     def test_put_entity_wo_key(self):
         _DATASET = 'DATASET'
         connection = _Connection()
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        batch = self._makeOne(client)
 
         self.assertRaises(ValueError, batch.put, _Entity())
 
     def test_put_entity_w_key_wrong_dataset_id(self):
         _DATASET = 'DATASET'
         connection = _Connection()
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        batch = self._makeOne(client)
         entity = _Entity()
         entity.key = _Key('OTHER')
 
@@ -120,7 +104,8 @@ def test_put_entity_w_partial_key(self):
         _DATASET = 'DATASET'
         _PROPERTIES = {'foo': 'bar'}
         connection = _Connection()
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        batch = self._makeOne(client)
         entity = _Entity(_PROPERTIES)
         key = entity.key = _Key(_DATASET)
         key._id = None
@@ -145,7 +130,8 @@ def test_put_entity_w_completed_key(self):
             'frotz': [],  # will be ignored
         }
         connection = _Connection()
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        batch = self._makeOne(client)
         entity = _Entity(_PROPERTIES)
         entity.exclude_from_indexes = ('baz', 'spam')
         key = entity.key = _Key(_DATASET)
@@ -180,7 +166,8 @@ def test_put_entity_w_completed_key_prefixed_dataset_id(self):
             'frotz': [],  # will be ignored
         }
         connection = _Connection()
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        batch = self._makeOne(client)
         entity = _Entity(_PROPERTIES)
         entity.exclude_from_indexes = ('baz', 'spam')
         key = entity.key = _Key('s~' + _DATASET)
@@ -209,7 +196,8 @@ def test_put_entity_w_completed_key_prefixed_dataset_id(self):
     def test_delete_w_partial_key(self):
         _DATASET = 'DATASET'
         connection = _Connection()
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        batch = self._makeOne(client)
         key = _Key(_DATASET)
         key._id = None
 
@@ -218,7 +206,8 @@ def test_delete_w_partial_key(self):
     def test_delete_w_key_wrong_dataset_id(self):
         _DATASET = 'DATASET'
         connection = _Connection()
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        batch = self._makeOne(client)
         key = _Key('OTHER')
 
         self.assertRaises(ValueError, batch.delete, key)
@@ -226,7 +215,8 @@ def test_delete_w_key_wrong_dataset_id(self):
     def test_delete_w_completed_key(self):
         _DATASET = 'DATASET'
         connection = _Connection()
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        batch = self._makeOne(client)
         key = _Key(_DATASET)
 
         batch.delete(key)
@@ -242,7 +232,8 @@ def test_delete_w_completed_key(self):
     def test_delete_w_completed_key_w_prefixed_dataset_id(self):
         _DATASET = 'DATASET'
         connection = _Connection()
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        batch = self._makeOne(client)
         key = _Key('s~' + _DATASET)
 
         batch.delete(key)
@@ -258,7 +249,8 @@ def test_delete_w_completed_key_w_prefixed_dataset_id(self):
     def test_commit(self):
         _DATASET = 'DATASET'
         connection = _Connection()
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        batch = self._makeOne(client)
 
         batch.commit()
 
@@ -269,7 +261,8 @@ def test_commit_w_auto_id_entities(self):
         _DATASET = 'DATASET'
         _NEW_ID = 1234
         connection = _Connection(_NEW_ID)
-        batch = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        batch = self._makeOne(client)
         entity = _Entity({})
         key = entity.key = _Key(_DATASET)
         key._id = None
@@ -283,21 +276,20 @@ def test_commit_w_auto_id_entities(self):
         self.assertEqual(entity.key._id, _NEW_ID)
 
     def test_as_context_mgr_wo_error(self):
-        from gcloud.datastore.batch import _BATCHES
        _DATASET = 'DATASET'
         _PROPERTIES = {'foo': 'bar'}
         connection = _Connection()
         entity = _Entity(_PROPERTIES)
         key = entity.key = _Key(_DATASET)
 
-        self.assertEqual(list(_BATCHES), [])
+        client = _Client(_DATASET, connection)
+        self.assertEqual(list(client._batches), [])
 
-        with self._makeOne(dataset_id=_DATASET,
-                           connection=connection) as batch:
-            self.assertEqual(list(_BATCHES), [batch])
+        with self._makeOne(client) as batch:
+            self.assertEqual(list(client._batches), [batch])
             batch.put(entity)
 
-        self.assertEqual(list(_BATCHES), [])
+        self.assertEqual(list(client._batches), [])
 
         insert_auto_ids = list(batch.mutation.insert_auto_id)
         self.assertEqual(len(insert_auto_ids), 0)
@@ -310,7 +302,6 @@ def test_as_context_mgr_wo_error(self):
                          [(_DATASET, batch.mutation, None)])
 
     def test_as_context_mgr_nested(self):
-        from gcloud.datastore.batch import _BATCHES
         _DATASET = 'DATASET'
         _PROPERTIES = {'foo': 'bar'}
         connection = _Connection()
@@ -319,20 +310,19 @@ def test_as_context_mgr_nested(self):
         entity2 = _Entity(_PROPERTIES)
         key2 = entity2.key = _Key(_DATASET)
 
-        self.assertEqual(list(_BATCHES), [])
+        client = _Client(_DATASET, connection)
+        self.assertEqual(list(client._batches), [])
 
-        with self._makeOne(dataset_id=_DATASET,
-                           connection=connection) as batch1:
-            self.assertEqual(list(_BATCHES), [batch1])
+        with self._makeOne(client) as batch1:
+            self.assertEqual(list(client._batches), [batch1])
             batch1.put(entity1)
-            with self._makeOne(dataset_id=_DATASET,
-                               connection=connection) as batch2:
-                self.assertEqual(list(_BATCHES), [batch2, batch1])
+            with self._makeOne(client) as batch2:
+                self.assertEqual(list(client._batches), [batch2, batch1])
                 batch2.put(entity2)
-            self.assertEqual(list(_BATCHES), [batch1])
+            self.assertEqual(list(client._batches), [batch1])
 
-        self.assertEqual(list(_BATCHES), [])
+        self.assertEqual(list(client._batches), [])
 
         insert_auto_ids = list(batch1.mutation.insert_auto_id)
         self.assertEqual(len(insert_auto_ids), 0)
@@ -355,25 +345,24 @@ def test_as_context_mgr_nested(self):
                           (_DATASET, batch1.mutation, None)])
 
     def test_as_context_mgr_w_error(self):
-        from gcloud.datastore.batch import _BATCHES
         _DATASET = 'DATASET'
         _PROPERTIES = {'foo': 'bar'}
         connection = _Connection()
         entity = _Entity(_PROPERTIES)
         key = entity.key = _Key(_DATASET)
 
-        self.assertEqual(list(_BATCHES), [])
+        client = _Client(_DATASET, connection)
+        self.assertEqual(list(client._batches), [])
 
         try:
-            with self._makeOne(dataset_id=_DATASET,
-                               connection=connection) as batch:
-                self.assertEqual(list(_BATCHES), [batch])
+            with self._makeOne(client) as batch:
+                self.assertEqual(list(client._batches), [batch])
                 batch.put(entity)
                 raise ValueError("testing")
         except ValueError:
             pass
 
-        self.assertEqual(list(_BATCHES), [])
+        self.assertEqual(list(client._batches), [])
 
         insert_auto_ids = list(batch.mutation.insert_auto_id)
         self.assertEqual(len(insert_auto_ids), 0)
@@ -454,3 +443,23 @@ def completed_key(self, new_id):
         new_key = self.__class__(self.dataset_id)
         new_key._id = new_id
         return new_key
+
+
+class _Client(object):
+
+    def __init__(self, dataset_id, connection, namespace=None):
+        self.dataset_id = dataset_id
+        self.connection = connection
+        self.namespace = namespace
+        self._batches = []
+
+    def _push_batch(self, batch):
+        self._batches.insert(0, batch)
+
+    def _pop_batch(self):
+        return self._batches.pop(0)
+
+    @property
+    def current_batch(self):
+        if self._batches:
+            return self._batches[0]
diff --git a/gcloud/datastore/test_client.py b/gcloud/datastore/test_client.py
index 9da4f0262946..a9fbc953ad29 100644
--- a/gcloud/datastore/test_client.py
+++ b/gcloud/datastore/test_client.py
@@ -64,6 +64,7 @@ def test_ctor_w_implicit_inputs(self):
         self.assertEqual(client.namespace, None)
         self.assertTrue(client.connection is conn)
         self.assertTrue(client.current_batch is None)
+        self.assertTrue(client.current_transaction is None)
 
     def test_ctor_w_explicit_inputs(self):
         OTHER = 'other'
@@ -78,21 +79,23 @@ def test_ctor_w_explicit_inputs(self):
         self.assertTrue(client.current_batch is None)
         self.assertEqual(list(client._batch_stack), [])
 
-    def test__push_connection_and__pop_connection(self):
+    def test__push_batch_and__pop_batch(self):
         conn = object()
-        batch1 = object()
-        batch2 = object()
         client = self._makeOne(connection=conn)
-        client._push_batch(batch1)
-        self.assertEqual(list(client._batch_stack), [batch1])
-        self.assertTrue(client.current_batch is batch1)
-        client._push_batch(batch2)
-        self.assertTrue(client.current_batch is batch2)
+        batch = client.batch()
+        xact = client.transaction()
+        client._push_batch(batch)
+        self.assertEqual(list(client._batch_stack), [batch])
+        self.assertTrue(client.current_batch is batch)
+        self.assertTrue(client.current_transaction is None)
+        client._push_batch(xact)
+        self.assertTrue(client.current_batch is xact)
+        self.assertTrue(client.current_transaction is xact)
         # list(_LocalStack) returns in reverse order.
-        self.assertEqual(list(client._batch_stack), [batch2, batch1])
-        self.assertTrue(client._pop_batch() is batch2)
-        self.assertEqual(list(client._batch_stack), [batch1])
-        self.assertTrue(client._pop_batch() is batch1)
+        self.assertEqual(list(client._batch_stack), [xact, batch])
+        self.assertTrue(client._pop_batch() is xact)
+        self.assertEqual(list(client._batch_stack), [batch])
+        self.assertTrue(client._pop_batch() is batch)
         self.assertEqual(list(client._batch_stack), [])
 
     def test_get_miss(self):
@@ -464,7 +467,7 @@ def test_put_multi_existing_batch_w_completed_key(self):
         entity = _Entity(foo=u'bar')
         key = entity.key = _Key(self.DATASET_ID)
 
-        with _NoCommitBatch(self.DATASET_ID, connection) as CURR_BATCH:
+        with _NoCommitBatch(client) as CURR_BATCH:
             result = client.put_multi([entity])
 
         self.assertEqual(result, None)
@@ -521,7 +524,7 @@ def test_delete_multi_w_existing_batch(self):
         client = self._makeOne(connection=connection)
         key = _Key(self.DATASET_ID)
 
-        with _NoCommitBatch(self.DATASET_ID, connection) as CURR_BATCH:
+        with _NoCommitBatch(client) as CURR_BATCH:
             result = client.delete_multi([key])
 
         self.assertEqual(result, None)
@@ -540,13 +543,13 @@ def test_delete_multi_w_existing_transaction(self):
         client = self._makeOne(connection=connection)
         key = _Key(self.DATASET_ID)
 
-        with _NoCommitTransaction(self.DATASET_ID, connection) as CURR_BATCH:
+        with _NoCommitTransaction(client) as CURR_XACT:
             result = client.delete_multi([key])
 
         self.assertEqual(result, None)
-        self.assertEqual(len(CURR_BATCH.mutation.insert_auto_id), 0)
-        self.assertEqual(len(CURR_BATCH.mutation.upsert), 0)
-        deletes = list(CURR_BATCH.mutation.delete)
+        self.assertEqual(len(CURR_XACT.mutation.insert_auto_id), 0)
+        self.assertEqual(len(CURR_XACT.mutation.upsert), 0)
+        deletes = list(CURR_XACT.mutation.delete)
         self.assertEqual(len(deletes), 1)
         self.assertEqual(deletes[0], key._key)
         self.assertEqual(len(connection._committed), 0)
@@ -638,7 +641,7 @@ def test_key_w_namespace_collision(self):
         }
         self.assertEqual(key.kwargs, expected_kwargs)
 
-    def test_batch_wo_connection(self):
+    def test_batch(self):
         from gcloud.datastore import client as MUT
         from gcloud._testing import _Monkey
 
@@ -648,28 +651,10 @@ def test_batch_wo_connection(self):
             batch = client.batch()
 
         self.assertTrue(isinstance(batch, _Dummy))
-        self.assertEqual(batch.args, ())
-        self.assertEqual(batch.kwargs,
-                         {'dataset_id': self.DATASET_ID,
-                          'connection': self.CONNECTION})
-
-    def test_batch_w_connection(self):
-        from gcloud.datastore import client as MUT
-        from gcloud._testing import _Monkey
-
-        connection = object()
-        client = self._makeOne(connection=connection)
-
-        with _Monkey(MUT, Batch=_Dummy):
-            batch = client.batch()
-
-        self.assertTrue(isinstance(batch, _Dummy))
-        self.assertEqual(batch.args, ())
-        self.assertEqual(batch.kwargs,
-                         {'dataset_id': self.DATASET_ID,
-                          'connection': connection})
+        self.assertEqual(batch.args, (client,))
+        self.assertEqual(batch.kwargs, {})
 
-    def test_transaction_wo_connection(self):
+    def test_transaction(self):
         from gcloud.datastore import client as MUT
         from gcloud._testing import _Monkey
 
@@ -679,28 +664,8 @@ def test_transaction_wo_connection(self):
             xact = client.transaction()
 
         self.assertTrue(isinstance(xact, _Dummy))
-        self.assertEqual(xact.args, ())
-        self.assertEqual(xact.kwargs,
-                         {'dataset_id': self.DATASET_ID,
-                          'connection': self.CONNECTION})
-
-    def test_transaction_w_connection(self):
-        from gcloud.datastore import client as MUT
-        from gcloud._testing import _Monkey
-
-        conn = object()
-        client = self._makeOne(connection=conn)
-
-        with _Monkey(MUT, Transaction=_Dummy):
-            xact = client.transaction()
-
-        self.assertTrue(isinstance(xact, _Dummy))
-        self.assertEqual(xact.args, ())
-        expected_kwargs = {
-            'dataset_id': self.DATASET_ID,
-            'connection': conn,
-        }
-        self.assertEqual(xact.kwargs, expected_kwargs)
+        self.assertEqual(xact.args, (client,))
+        self.assertEqual(xact.kwargs, {})
 
     def test_query_w_dataset_id(self):
         KIND = 'KIND'
@@ -824,32 +789,30 @@ def request(self, **kw):
 
 class _NoCommitBatch(object):
 
-    def __init__(self, dataset_id, connection):
+    def __init__(self, client):
         from gcloud.datastore.batch import Batch
-        self._batch = Batch(dataset_id, connection)
+        self._client = client
+        self._batch = Batch(client)
 
     def __enter__(self):
-        from gcloud.datastore.batch import _BATCHES
-        _BATCHES.push(self._batch)
+        self._client._push_batch(self._batch)
         return self._batch
 
    def __exit__(self, *args):
-        from gcloud.datastore.batch import _BATCHES
-        _BATCHES.pop()
+        self._client._pop_batch()
 
 
 class _NoCommitTransaction(object):
 
-    def __init__(self, dataset_id, connection, transaction_id='TRANSACTION'):
+    def __init__(self, client, transaction_id='TRANSACTION'):
         from gcloud.datastore.transaction import Transaction
-        xact = self._transaction = Transaction(dataset_id, connection)
+        self._client = client
+        xact = self._transaction = Transaction(client)
         xact._id = transaction_id
 
     def __enter__(self):
-        from gcloud.datastore.batch import _BATCHES
-        _BATCHES.push(self._transaction)
+        self._client._push_batch(self._transaction)
         return self._transaction
 
    def __exit__(self, *args):
-        from gcloud.datastore.batch import _BATCHES
-        _BATCHES.pop()
+        self._client._pop_batch()
diff --git a/gcloud/datastore/test_query.py b/gcloud/datastore/test_query.py
index b3488d556a75..a3cdaf4982bb 100644
--- a/gcloud/datastore/test_query.py
+++ b/gcloud/datastore/test_query.py
@@ -17,13 +17,7 @@
 class TestQuery(unittest2.TestCase):
 
-    def setUp(self):
-        from gcloud.datastore._testing import _setup_defaults
-        _setup_defaults(self)
-
-    def tearDown(self):
-        from gcloud.datastore._testing import _tear_down_defaults
-        _tear_down_defaults(self)
+    _DATASET = 'DATASET'
 
     def _getTargetClass(self):
         from gcloud.datastore.query import Query
@@ -32,18 +26,18 @@ def _getTargetClass(self):
     def _makeOne(self, *args, **kw):
         return self._getTargetClass()(*args, **kw)
 
-    def test_ctor_defaults_wo_implicit_dataset_id(self):
-        self.assertRaises(ValueError, self._makeOne)
-
-    def test_ctor_defaults_w_implicit_dataset_id(self):
-        from gcloud.datastore._testing import _monkey_defaults
+    def _makeClient(self, connection=None):
+        if connection is None:
+            connection = _Connection()
+        return _Client(self._DATASET, connection)
 
-        _DATASET = 'DATASET'
-        with _monkey_defaults(dataset_id=_DATASET):
-            query = self._makeOne()
-        self.assertEqual(query.dataset_id, _DATASET)
+    def test_ctor_defaults(self):
+        client = self._makeClient()
+        query = self._makeOne(client)
+        self.assertTrue(query._client is client)
+        self.assertEqual(query.dataset_id, client.dataset_id)
         self.assertEqual(query.kind, None)
-        self.assertEqual(query.namespace, None)
+        self.assertEqual(query.namespace, client.namespace)
         self.assertEqual(query.ancestor, None)
         self.assertEqual(query.filters, [])
         self.assertEqual(query.projection, [])
@@ -52,15 +46,17 @@ def test_ctor_defaults_w_implicit_dataset_id(self):
 
     def test_ctor_explicit(self):
         from gcloud.datastore.key import Key
-        _DATASET = 'DATASET'
+        _DATASET = 'OTHER_DATASET'
         _KIND = 'KIND'
-        _NAMESPACE = 'NAMESPACE'
+        _NAMESPACE = 'OTHER_NAMESPACE'
+        client = self._makeClient()
         ancestor = Key('ANCESTOR', 123, dataset_id=_DATASET)
         FILTERS = [('foo', '=', 'Qux'), ('bar', '<', 17)]
         PROJECTION = ['foo', 'bar', 'baz']
         ORDER = ['foo', 'bar']
         GROUP_BY = ['foo']
         query = self._makeOne(
+            client,
             kind=_KIND,
             dataset_id=_DATASET,
             namespace=_NAMESPACE,
@@ -70,6 +66,7 @@ def test_ctor_explicit(self):
             order=ORDER,
             group_by=GROUP_BY,
         )
+        self.assertTrue(query._client is client)
         self.assertEqual(query.dataset_id, _DATASET)
         self.assertEqual(query.kind, _KIND)
         self.assertEqual(query.namespace, _NAMESPACE)
@@ -80,36 +77,27 @@ def test_ctor_explicit(self):
         self.assertEqual(query.group_by, GROUP_BY)
 
     def test_ctor_bad_projection(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
         BAD_PROJECTION = object()
-        self.assertRaises(TypeError, self._makeOne, _KIND, _DATASET,
+        self.assertRaises(TypeError, self._makeOne, self._makeClient(),
                           projection=BAD_PROJECTION)
 
     def test_ctor_bad_order(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
         BAD_ORDER = object()
-        self.assertRaises(TypeError, self._makeOne, _KIND, _DATASET,
+        self.assertRaises(TypeError, self._makeOne, self._makeClient(),
                           order=BAD_ORDER)
 
     def test_ctor_bad_group_by(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
         BAD_GROUP_BY = object()
-        self.assertRaises(TypeError, self._makeOne, _KIND, _DATASET,
+        self.assertRaises(TypeError, self._makeOne, self._makeClient(),
                           group_by=BAD_GROUP_BY)
 
     def test_ctor_bad_filters(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
         FILTERS_CANT_UNPACK = [('one', 'two')]
-        self.assertRaises(ValueError, self._makeOne, _KIND, _DATASET,
+        self.assertRaises(ValueError, self._makeOne, self._makeClient(),
                           filters=FILTERS_CANT_UNPACK)
 
     def test_namespace_setter_w_non_string(self):
-        _DATASET = 'DATASET'
-        query = self._makeOne(dataset_id=_DATASET)
+        query = self._makeOne(self._makeClient())
 
         def _assign(val):
             query.namespace = val
@@ -117,16 +105,13 @@ def _assign(val):
         self.assertRaises(ValueError, _assign, object())
 
     def test_namespace_setter(self):
-        _DATASET = 'DATASET'
-        _NAMESPACE = 'NAMESPACE'
-        query = self._makeOne(dataset_id=_DATASET)
+        _NAMESPACE = 'OTHER_NAMESPACE'
+        query = self._makeOne(self._makeClient())
         query.namespace = _NAMESPACE
-        self.assertEqual(query.dataset_id, _DATASET)
         self.assertEqual(query.namespace, _NAMESPACE)
 
     def test_kind_setter_w_non_string(self):
-        _DATASET = 'DATASET'
-        query = self._makeOne(dataset_id=_DATASET)
+        query = self._makeOne(self._makeClient())
 
         def _assign(val):
             query.kind = val
@@ -134,26 +119,22 @@ def _assign(val):
         self.assertRaises(TypeError, _assign, object())
 
     def test_kind_setter_wo_existing(self):
-        _DATASET = 'DATASET'
         _KIND = 'KIND'
-        query = self._makeOne(dataset_id=_DATASET)
+        query = self._makeOne(self._makeClient())
         query.kind = _KIND
-        self.assertEqual(query.dataset_id, _DATASET)
         self.assertEqual(query.kind, _KIND)
 
     def test_kind_setter_w_existing(self):
-        _DATASET = 'DATASET'
         _KIND_BEFORE = 'KIND_BEFORE'
         _KIND_AFTER = 'KIND_AFTER'
-        query = self._makeOne(_KIND_BEFORE, _DATASET)
+        query = self._makeOne(self._makeClient(), kind=_KIND_BEFORE)
         self.assertEqual(query.kind, _KIND_BEFORE)
         query.kind = _KIND_AFTER
-        self.assertEqual(query.dataset_id, _DATASET)
+        self.assertEqual(query.dataset_id, self._DATASET)
         self.assertEqual(query.kind, _KIND_AFTER)
 
     def test_ancestor_setter_w_non_key(self):
-        _DATASET = 'DATASET'
-        query = self._makeOne(dataset_id=_DATASET)
+        query = self._makeOne(self._makeClient())
 
         def _assign(val):
             query.ancestor = val
@@ -163,37 +144,32 @@ def _assign(val):
     def test_ancestor_setter_w_key(self):
         from gcloud.datastore.key import Key
-        _DATASET = 'DATASET'
         _NAME = u'NAME'
-        key = Key('KIND', 123, dataset_id='DATASET')
-        query = self._makeOne(dataset_id=_DATASET)
+        key = Key('KIND', 123, dataset_id=self._DATASET)
+        query = self._makeOne(self._makeClient())
         query.add_filter('name', '=', _NAME)
         query.ancestor = key
         self.assertEqual(query.ancestor.path, key.path)
 
     def test_ancestor_deleter_w_key(self):
         from gcloud.datastore.key import Key
-        _DATASET = 'DATASET'
-        key = Key('KIND', 123, dataset_id='DATASET')
-        query = self._makeOne(dataset_id=_DATASET, ancestor=key)
+        key = Key('KIND', 123, dataset_id=self._DATASET)
+        query = self._makeOne(client=self._makeClient(), ancestor=key)
         del query.ancestor
         self.assertTrue(query.ancestor is None)
 
     def test_add_filter_setter_w_unknown_operator(self):
-        _DATASET = 'DATASET'
-        query = self._makeOne(dataset_id=_DATASET)
+        query = self._makeOne(self._makeClient())
         self.assertRaises(ValueError, query.add_filter,
                           'firstname', '~~', 'John')
 
     def test_add_filter_w_known_operator(self):
-        _DATASET = 'DATASET'
-        query = self._makeOne(dataset_id=_DATASET)
+        query = self._makeOne(self._makeClient())
         query.add_filter('firstname', '=', u'John')
         self.assertEqual(query.filters, [('firstname', '=', u'John')])
 
     def test_add_filter_w_all_operators(self):
-        _DATASET = 'DATASET'
-        query = self._makeOne(dataset_id=_DATASET)
+        query = self._makeOne(self._makeClient())
         query.add_filter('leq_prop', '<=', u'val1')
         query.add_filter('geq_prop', '>=', u'val2')
         query.add_filter('lt_prop', '<', u'val3')
@@ -208,8 +184,7 @@ def test_add_filter_w_all_operators(self):
 
     def test_add_filter_w_known_operator_and_entity(self):
         from gcloud.datastore.entity import Entity
-        _DATASET = 'DATASET'
-        query = self._makeOne(dataset_id=_DATASET)
+        query = self._makeOne(self._makeClient())
         other = Entity()
         other['firstname'] = u'John'
         other['lastname'] = u'Smith'
@@ -217,159 +192,120 @@ def test_add_filter_w_known_operator_and_entity(self):
         self.assertEqual(query.filters, [('other', '=', other)])
 
     def test_add_filter_w_whitespace_property_name(self):
-        _DATASET = 'DATASET'
-        query = self._makeOne(dataset_id=_DATASET)
+        query = self._makeOne(self._makeClient())
         PROPERTY_NAME = ' property with lots of space '
         query.add_filter(PROPERTY_NAME, '=', u'John')
         self.assertEqual(query.filters, [(PROPERTY_NAME, '=', u'John')])
 
     def test_add_filter___key__valid_key(self):
         from gcloud.datastore.key import Key
-        _DATASET = 'DATASET'
-        query = self._makeOne(dataset_id=_DATASET)
-        key = Key('Foo', dataset_id='DATASET')
+        query = self._makeOne(self._makeClient())
+        key = Key('Foo', dataset_id=self._DATASET)
         query.add_filter('__key__', '=', key)
         self.assertEqual(query.filters, [('__key__', '=', key)])
 
     def test_filter___key__not_equal_operator(self):
         from gcloud.datastore.key import Key
-        _DATASET = 'DATASET'
-        key = Key('Foo', dataset_id='DATASET')
-        query = self._makeOne(dataset_id=_DATASET)
+        key = Key('Foo', dataset_id=self._DATASET)
+        query = self._makeOne(self._makeClient())
         query.add_filter('__key__', '<', key)
         self.assertEqual(query.filters, [('__key__', '<', key)])
 
     def test_filter___key__invalid_value(self):
-        _DATASET = 'DATASET'
-        query = self._makeOne(dataset_id=_DATASET)
+        query = self._makeOne(self._makeClient())
         self.assertRaises(ValueError, query.add_filter, '__key__', '=', None)
 
     def test_projection_setter_empty(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
-        query = self._makeOne(_KIND, _DATASET)
+        query = self._makeOne(self._makeClient())
         query.projection = []
         self.assertEqual(query.projection, [])
 
     def test_projection_setter_string(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
-        query = self._makeOne(_KIND, _DATASET)
+        query = self._makeOne(self._makeClient())
         query.projection = 'field1'
         self.assertEqual(query.projection, ['field1'])
 
     def test_projection_setter_non_empty(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
-        query = self._makeOne(_KIND, _DATASET)
+        query = self._makeOne(self._makeClient())
         query.projection = ['field1', 'field2']
         self.assertEqual(query.projection, ['field1', 'field2'])
 
     def test_projection_setter_multiple_calls(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
         _PROJECTION1 = ['field1', 'field2']
         _PROJECTION2 = ['field3']
-        query = self._makeOne(_KIND, _DATASET)
+        query = self._makeOne(self._makeClient())
         query.projection = _PROJECTION1
         self.assertEqual(query.projection, _PROJECTION1)
         query.projection = _PROJECTION2
         self.assertEqual(query.projection, _PROJECTION2)
 
     def test_keys_only(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
-        query = self._makeOne(_KIND, _DATASET)
+        query = self._makeOne(self._makeClient())
         query.keys_only()
         self.assertEqual(query.projection, ['__key__'])
 
     def test_order_setter_empty(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
-        query = self._makeOne(_KIND, _DATASET, order=['foo', '-bar'])
+        query = self._makeOne(self._makeClient(), order=['foo', '-bar'])
         query.order = []
         self.assertEqual(query.order, [])
 
     def test_order_setter_string(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
-        query = self._makeOne(_KIND, _DATASET)
+        query = self._makeOne(self._makeClient())
         query.order = 'field'
         self.assertEqual(query.order, ['field'])
 
     def test_order_setter_single_item_list_desc(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
-        query = self._makeOne(_KIND, _DATASET)
+        query = self._makeOne(self._makeClient())
         query.order = ['-field']
         self.assertEqual(query.order, ['-field'])
 
     def test_order_setter_multiple(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
-        query = self._makeOne(_KIND, _DATASET)
+        query = self._makeOne(self._makeClient())
         query.order = ['foo', '-bar']
         self.assertEqual(query.order, ['foo', '-bar'])
 
     def test_group_by_setter_empty(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
-        query = self._makeOne(_KIND, _DATASET, group_by=['foo', 'bar'])
+        query = self._makeOne(self._makeClient(), group_by=['foo', 'bar'])
         query.group_by = []
         self.assertEqual(query.group_by, [])
 
     def test_group_by_setter_string(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
-        query = self._makeOne(_KIND, _DATASET)
+        query = self._makeOne(self._makeClient())
         query.group_by = 'field1'
         self.assertEqual(query.group_by, ['field1'])
 
     def test_group_by_setter_non_empty(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
-        query = self._makeOne(_KIND, _DATASET)
+        query = self._makeOne(self._makeClient())
         query.group_by = ['field1', 'field2']
         self.assertEqual(query.group_by, ['field1', 'field2'])
 
     def test_group_by_multiple_calls(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
         _GROUP_BY1 = ['field1', 'field2']
         _GROUP_BY2 = ['field3']
-        query = self._makeOne(_KIND, _DATASET)
+        query = self._makeOne(self._makeClient())
         query.group_by = _GROUP_BY1
         self.assertEqual(query.group_by, _GROUP_BY1)
         query.group_by = _GROUP_BY2
         self.assertEqual(query.group_by, _GROUP_BY2)
 
-    def test_fetch_defaults_wo_implicit_connection(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
-        query = self._makeOne(_KIND, _DATASET)
-        self.assertRaises(ValueError, query.fetch)
-
-    def test_fetch_defaults_w_implicit_connection(self):
-        from gcloud.datastore._testing import _monkey_defaults
-
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
+    def test_fetch_defaults_w_client_attr(self):
         connection = _Connection()
-        query = self._makeOne(_KIND, _DATASET)
-
-        with _monkey_defaults(connection=connection):
-            iterator = query.fetch()
+        client = self._makeClient(connection)
+        query = self._makeOne(client)
+        iterator = query.fetch()
 
         self.assertTrue(iterator._query is query)
+        self.assertTrue(iterator._client is client)
         self.assertEqual(iterator._limit, None)
         self.assertEqual(iterator._offset, 0)
 
-    def test_fetch_explicit(self):
-        _DATASET = 'DATASET'
-        _KIND = 'KIND'
+    def test_fetch_w_explicit_client(self):
         connection = _Connection()
-        query = self._makeOne(_KIND, _DATASET)
-        iterator = query.fetch(limit=7, offset=8, connection=connection)
+        client = self._makeClient(connection)
+        other_client = self._makeClient(connection)
+        query = self._makeOne(client)
+        iterator = query.fetch(limit=7, offset=8, client=other_client)
 
         self.assertTrue(iterator._query is query)
+        self.assertTrue(iterator._client is other_client)
         self.assertEqual(iterator._limit, 7)
         self.assertEqual(iterator._offset, 8)
 
@@ -405,6 +341,11 @@ def _addQueryResults(self, connection, cursor=_END, more=False):
         connection._results.append(
             ([entity_pb], cursor, MORE if more else NO_MORE))
 
+    def _makeClient(self, connection=None):
+        if connection is None:
+            connection = _Connection()
+        return _Client(self._DATASET, connection)
+
     def test_ctor_defaults(self):
         connection = _Connection()
         query = object()
@@ -414,9 +355,9 @@ def test_ctor_defaults(self):
         self.assertEqual(iterator._offset, 0)
 
     def test_ctor_explicit(self):
-        connection = _Connection()
-        query = _Query()
-        iterator = self._makeOne(query, connection, 13, 29)
+        client = self._makeClient()
+        query = _Query(client)
+        iterator = self._makeOne(query, client, 13, 29)
         self.assertTrue(iterator._query is query)
         self.assertEqual(iterator._limit, 13)
         self.assertEqual(iterator._offset, 29)
@@ -425,9 +366,10 @@ def test_next_page_no_cursors_no_more(self):
         from base64 import b64encode
         from gcloud.datastore.query import _pb_from_query
         connection = _Connection()
-        query = _Query(self._DATASET, self._KIND, self._NAMESPACE)
+        client = self._makeClient(connection)
+        query = _Query(client, self._KIND, self._DATASET, self._NAMESPACE)
         self._addQueryResults(connection)
-        iterator = self._makeOne(query, connection)
+        iterator = self._makeOne(query, client)
         entities, more_results, cursor = iterator.next_page()
 
         self.assertEqual(cursor, b64encode(self._END))
@@ -451,9 +393,10 @@ def test_next_page_no_cursors_no_more_w_offset_and_limit(self):
         from base64 import b64encode
         from gcloud.datastore.query import _pb_from_query
         connection = _Connection()
-        query = _Query(self._DATASET, self._KIND, self._NAMESPACE)
+        client = self._makeClient(connection)
+        query = _Query(client, self._KIND, self._DATASET, self._NAMESPACE)
         self._addQueryResults(connection)
-        iterator = self._makeOne(query, connection, 13, 29)
+        iterator = self._makeOne(query, client, 13, 29)
         entities, more_results, cursor = iterator.next_page()
 
         self.assertEqual(cursor, b64encode(self._END))
@@ -479,9 +422,10 @@ def test_next_page_w_cursors_w_more(self):
         from base64 import b64encode
         from gcloud.datastore.query import _pb_from_query
         connection = _Connection()
-        query = _Query(self._DATASET, self._KIND, self._NAMESPACE)
+        client = self._makeClient(connection)
+        query = _Query(client, self._KIND, self._DATASET, self._NAMESPACE)
         self._addQueryResults(connection, cursor=self._END, more=True)
-        iterator = self._makeOne(query, connection)
+        iterator = self._makeOne(query, client)
         iterator._start_cursor = self._START
         iterator._end_cursor = self._END
         entities, more_results, cursor = iterator.next_page()
@@ -509,19 +453,21 @@ def test_next_page_w_cursors_w_more(self):
 
     def test_next_page_w_cursors_w_bogus_more(self):
         connection = _Connection()
-        query = _Query(self._DATASET, self._KIND, self._NAMESPACE)
+        client = self._makeClient(connection)
+        query = _Query(client, self._KIND, self._DATASET, self._NAMESPACE)
         self._addQueryResults(connection, cursor=self._END, more=True)
         epb, cursor, _ = connection._results.pop()
         connection._results.append((epb, cursor, 4))  # invalid enum
-        iterator = self._makeOne(query, connection)
+        iterator = self._makeOne(query, client)
         self.assertRaises(ValueError, iterator.next_page)
 
     def test___iter___no_more(self):
         from gcloud.datastore.query import _pb_from_query
         connection = _Connection()
-        query = _Query(self._DATASET, self._KIND, self._NAMESPACE)
+        client = self._makeClient(connection)
+        query = _Query(client, self._KIND, self._DATASET, self._NAMESPACE)
         self._addQueryResults(connection)
-        iterator = self._makeOne(query, connection)
+        iterator = self._makeOne(query, client)
         entities = list(iterator)
 
         self.assertFalse(iterator._more_results)
@@ -542,10 +488,11 @@ def test___iter___no_more(self):
     def test___iter___w_more(self):
         from gcloud.datastore.query import _pb_from_query
         connection = _Connection()
-        query = _Query(self._DATASET, self._KIND, self._NAMESPACE)
+        client = self._makeClient(connection)
+        query = _Query(client, self._KIND, self._DATASET, self._NAMESPACE)
         self._addQueryResults(connection, cursor=self._END, more=True)
         self._addQueryResults(connection)
-        iterator = self._makeOne(query, connection)
+        iterator = self._makeOne(query, client)
         entities = list(iterator)
 
         self.assertFalse(iterator._more_results)
@@ -673,16 +620,18 @@ def test_group_by(self):
 class _Query(object):
 
     def __init__(self,
-                 dataset_id=None,
+                 client=object(),
                  kind=None,
+                 dataset_id=None,
                  namespace=None,
                  ancestor=None,
                  filters=(),
                  projection=(),
                  order=(),
                  group_by=()):
-        self.dataset_id = dataset_id
+        self._client = client
         self.kind = kind
+        self.dataset_id = dataset_id
         self.namespace = namespace
         self.ancestor = ancestor
         self.filters = filters
@@ -705,3 +654,15 @@ def run_query(self, **kw):
         self._called_with.append(kw)
         result, self._results = self._results[0], self._results[1:]
         return result
+
+
+class _Client(object):
+
+    def __init__(self, dataset_id, connection, namespace=None):
+        self.dataset_id = dataset_id
+        self.connection = connection
+        self.namespace = namespace
+
+    @property
+    def current_transaction(self):
+        pass
diff --git a/gcloud/datastore/test_transaction.py b/gcloud/datastore/test_transaction.py
index 8376425d964a..7d0f822a2e3a 100644
--- a/gcloud/datastore/test_transaction.py
+++ b/gcloud/datastore/test_transaction.py
@@ -17,41 +17,20 @@
 class TestTransaction(unittest2.TestCase):
 
-    def setUp(self):
-        from gcloud.datastore._testing import _setup_defaults
-        _setup_defaults(self)
-
-    def tearDown(self):
-        from gcloud.datastore._testing import _tear_down_defaults
-        _tear_down_defaults(self)
-
     def _getTargetClass(self):
         from gcloud.datastore.transaction import Transaction
-
         return Transaction
 
-    def _makeOne(self, dataset_id=None, connection=None):
-        return self._getTargetClass()(dataset_id=dataset_id,
-                                      connection=connection)
-
-    def test_ctor_missing_required(self):
-        from gcloud.datastore import _implicit_environ
-
-        self.assertEqual(_implicit_environ.get_default_dataset_id(), None)
-
-        with self.assertRaises(ValueError):
-            self._makeOne()
-        with self.assertRaises(ValueError):
-            self._makeOne(dataset_id=object())
-        with self.assertRaises(ValueError):
-            self._makeOne(connection=object())
+    def _makeOne(self, client):
+        return self._getTargetClass()(client)
 
     def test_ctor(self):
         from gcloud.datastore._datastore_v1_pb2 import Mutation
         _DATASET = 'DATASET'
         connection = _Connection()
-        xact = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        xact = self._makeOne(client)
         self.assertEqual(xact.dataset_id, _DATASET)
         self.assertEqual(xact.connection, connection)
         self.assertEqual(xact.id, None)
@@ -59,37 +38,25 @@ def test_ctor(self):
         self.assertTrue(isinstance(xact.mutation, Mutation))
         self.assertEqual(len(xact._auto_id_entities), 0)
 
-    def test_ctor_with_env(self):
-        from gcloud.datastore._testing import _monkey_defaults
-
-        CONNECTION = _Connection()
-        DATASET_ID = 'DATASET'
-        with _monkey_defaults(connection=CONNECTION, dataset_id=DATASET_ID):
-            xact = self._makeOne()
-
-        self.assertEqual(xact.id, None)
-        self.assertEqual(xact.dataset_id, DATASET_ID)
-        self.assertEqual(xact.connection, CONNECTION)
-        self.assertEqual(xact._status, self._getTargetClass()._INITIAL)
-
     def test_current(self):
         from gcloud.datastore.test_client import _NoCommitBatch
         _DATASET = 'DATASET'
         connection = _Connection()
-        xact1 = self._makeOne(_DATASET, connection)
-        xact2 = self._makeOne(_DATASET, connection)
+        client = _Client(_DATASET, connection)
+        xact1 = self._makeOne(client)
+        xact2 = self._makeOne(client)
         self.assertTrue(xact1.current() is None)
         self.assertTrue(xact2.current() is None)
         with xact1:
             self.assertTrue(xact1.current() is xact1)
             self.assertTrue(xact2.current() is xact1)
-            with _NoCommitBatch(_DATASET, _Connection):
+            with _NoCommitBatch(client):
                 self.assertTrue(xact1.current() is None)
                 self.assertTrue(xact2.current() is None)
             with xact2:
                 self.assertTrue(xact1.current() is xact2)
                 self.assertTrue(xact2.current() is xact2)
-                with _NoCommitBatch(_DATASET, _Connection):
+                with _NoCommitBatch(client):
                     self.assertTrue(xact1.current() is None)
                     self.assertTrue(xact2.current() is None)
             self.assertTrue(xact1.current() is xact1)
@@ -100,7 +67,8 @@ def test_current(self):
     def test_begin(self):
         _DATASET = 'DATASET'
         connection = _Connection(234)
-        xact = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        xact = self._makeOne(client)
         xact.begin()
         self.assertEqual(xact.id, 234)
         self.assertEqual(connection._begun, _DATASET)
@@ -108,7 +76,8 @@ def test_begin(self):
     def test_begin_tombstoned(self):
         _DATASET = 'DATASET'
         connection = _Connection(234)
-        xact = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        xact = self._makeOne(client)
         xact.begin()
         self.assertEqual(xact.id, 234)
         self.assertEqual(connection._begun, _DATASET)
@@ -121,7 +90,8 @@ def test_begin_tombstoned(self):
     def test_rollback(self):
         _DATASET = 'DATASET'
         connection = _Connection(234)
-        xact = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        xact = self._makeOne(client)
         xact.begin()
         xact.rollback()
         self.assertEqual(xact.id, None)
@@ -130,7 +100,8 @@ def test_rollback(self):
     def test_commit_no_auto_ids(self):
         _DATASET = 'DATASET'
         connection = _Connection(234)
-        xact = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        xact = self._makeOne(client)
         xact._mutation = mutation = object()
         xact.begin()
         xact.commit()
@@ -144,7 +115,8 @@ def test_commit_w_auto_ids(self):
         connection = _Connection(234)
         connection._commit_result = _CommitResult(
             _make_key(_KIND, _ID, _DATASET))
-        xact = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        xact = self._makeOne(client)
         entity = _Entity()
         xact.add_auto_id_entity(entity)
         xact._mutation = mutation = object()
@@ -157,7 +129,8 @@ def test_commit_w_auto_ids(self):
     def test_context_manager_no_raise(self):
         _DATASET = 'DATASET'
         connection = _Connection(234)
-        xact = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        xact = self._makeOne(client)
         xact._mutation = mutation = object()
         with xact:
             self.assertEqual(xact.id, 234)
@@ -166,11 +139,14 @@ def test_context_manager_no_raise(self):
         self.assertEqual(xact.id, None)
 
     def test_context_manager_w_raise(self):
+
         class Foo(Exception):
             pass
+
         _DATASET = 'DATASET'
         connection = _Connection(234)
-        xact = self._makeOne(dataset_id=_DATASET, connection=connection)
+        client = _Client(_DATASET, connection)
+        xact = self._makeOne(client)
         xact._mutation = object()
         try:
             with xact:
@@ -226,3 +202,23 @@ class _Entity(object):
     def __init__(self):
         from gcloud.datastore.key import Key
         self.key = Key('KIND', dataset_id='DATASET')
+
+
+class _Client(object):
+
+    def __init__(self, dataset_id, connection, namespace=None):
+        self.dataset_id = dataset_id
+        self.connection = connection
+        self.namespace = namespace
+        self._batches = []
+
+    def _push_batch(self, batch):
+        self._batches.insert(0, batch)
+
+    def _pop_batch(self):
+        return self._batches.pop(0)
+
+    @property
+    def current_batch(self):
+        if self._batches:
+            return self._batches[0]
diff --git a/gcloud/datastore/transaction.py b/gcloud/datastore/transaction.py
index bbc88b941936..5a60f0bdd382 100644
--- a/gcloud/datastore/transaction.py
+++ b/gcloud/datastore/transaction.py
@@ -84,14 +84,8 @@ class Transaction(Batch):
     ...     else:
     ...         transaction.commit()
 
-    :type dataset_id: string
-    :param dataset_id: The ID of the dataset.
-
-    :type connection: :class:`gcloud.datastore.connection.Connection`
-    :param connection: The connection used to connect to datastore.
-
-    :raises: :class:`ValueError` if either a connection or dataset ID
-             are not set.
+    :type client: :class:`gcloud.datastore.client.Client`
+    :param client: The client used to connect to datastore.
     """
 
     _INITIAL = 0
@@ -106,8 +100,8 @@ class Transaction(Batch):
     _FINISHED = 3
     """Enum value for _FINISHED status of transaction."""
 
-    def __init__(self, dataset_id=None, connection=None):
-        super(Transaction, self).__init__(dataset_id, connection)
+    def __init__(self, client):
+        super(Transaction, self).__init__(client)
         self._id = None
         self._status = self._INITIAL
 
@@ -120,8 +114,7 @@ def id(self):
         """
         return self._id
 
-    @staticmethod
-    def current():
+    def current(self):
         """Return the topmost transaction.
 
         .. note:: if the topmost element on the stack is not a transaction,
@@ -129,7 +122,7 @@ def current(self):
 
         :rtype: :class:`gcloud.datastore.transaction.Transaction` or None
         """
-        top = Batch.current()
+        top = super(Transaction, self).current()
         if isinstance(top, Transaction):
             return top
 
@@ -145,7 +138,7 @@ def begin(self):
         if self._status != self._INITIAL:
             raise ValueError('Transaction already started previously.')
         self._status = self._IN_PROGRESS
-        self._id = self.connection.begin_transaction(self._dataset_id)
+        self._id = self.connection.begin_transaction(self.dataset_id)
 
     def rollback(self):
         """Rolls back the current transaction.
@@ -156,7 +149,7 @@ def rollback(self):
         - Sets the current transaction's ID to None.
         """
         try:
-            self.connection.rollback(self._dataset_id, self._id)
+            self.connection.rollback(self.dataset_id, self._id)
         finally:
             self._status = self._ABORTED
             # Clear our own ID in case this gets accidentally reused.
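
Reviewer note: the behavioral heart of this change is that the batch/transaction stack moves off the module-global _BATCHES _LocalStack in gcloud/datastore/batch.py and onto the client, with current() becoming an instance method that consults client.current_batch. Below is a minimal, runnable Python sketch of just that mechanism, mirroring the _Client test doubles added in test_batch.py and test_transaction.py. It is illustrative only, not the real gcloud classes: the actual Batch.__enter__ also calls begin(), and the actual __exit__ commits or rolls back before popping, both of which are omitted here.

    class _Client(object):
        """Owns connection/dataset config plus a per-client batch stack."""

        def __init__(self, dataset_id, connection, namespace=None):
            self.dataset_id = dataset_id
            self.connection = connection
            self.namespace = namespace
            self._batches = []  # index 0 is the top of the stack

        def _push_batch(self, batch):
            self._batches.insert(0, batch)

        def _pop_batch(self):
            return self._batches.pop(0)

        @property
        def current_batch(self):
            if self._batches:
                return self._batches[0]
            return None

        @property
        def current_transaction(self):
            # Only report the top of the stack when it really is a
            # transaction, mirroring Client.current_transaction above.
            batch = self.current_batch
            if isinstance(batch, Transaction):
                return batch
            return None


    class Batch(object):
        """Proxies configuration to its client, as in this diff."""

        def __init__(self, client):
            self._client = client

        @property
        def dataset_id(self):
            return self._client.dataset_id

        def current(self):
            return self._client.current_batch

        def __enter__(self):
            # The real Batch.__enter__ also calls self.begin().
            self._client._push_batch(self)
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # The real __exit__ commits on success / rolls back on error first.
            self._client._pop_batch()


    class Transaction(Batch):
        pass


    client = _Client('DATASET', connection=object())
    with Batch(client) as batch:
        assert client.current_batch is batch
        assert client.current_transaction is None  # plain batch is filtered out
        with Transaction(client) as xact:
            assert client.current_batch is xact    # nesting shadows the outer batch
            assert client.current_transaction is xact
        assert client.current_batch is batch
    assert client.current_batch is None

Because each client owns its own stack, two clients in the same process no longer share the ambient state that the old module-level _BATCHES forced on them, and Batch.current() / Transaction.current() become instance methods that consult the owning client rather than a global.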