diff --git a/datastore/google/cloud/datastore/_http.py b/datastore/google/cloud/datastore/_http.py index ac9059ff0340..2f0b27bde952 100644 --- a/datastore/google/cloud/datastore/_http.py +++ b/datastore/google/cloud/datastore/_http.py @@ -426,10 +426,8 @@ def commit(self, project, request, transaction_id): This method will mutate ``request`` before using it. - :rtype: tuple - :returns: The pair of the number of index updates and a list of - :class:`.entity_pb2.Key` for each incomplete key - that was completed in the commit. + :rtype: :class:`.datastore_pb2.CommitResponse` + :returns: The protobuf response from a commit request. """ if transaction_id: request.mode = _datastore_pb2.CommitRequest.TRANSACTIONAL @@ -437,8 +435,7 @@ def commit(self, project, request, transaction_id): else: request.mode = _datastore_pb2.CommitRequest.NON_TRANSACTIONAL - response = self._datastore_api.commit(project, request) - return _parse_commit_response(response) + return self._datastore_api.commit(project, request) def rollback(self, project, transaction_id): """Rollback the connection's existing transaction. @@ -508,21 +505,3 @@ def _add_keys_to_request(request_field_pb, key_pbs): """ for key_pb in key_pbs: request_field_pb.add().CopyFrom(key_pb) - - -def _parse_commit_response(commit_response_pb): - """Extract response data from a commit response. - - :type commit_response_pb: :class:`.datastore_pb2.CommitResponse` - :param commit_response_pb: The protobuf response from a commit request. - - :rtype: tuple - :returns: The pair of the number of index updates and a list of - :class:`.entity_pb2.Key` for each incomplete key - that was completed in the commit. - """ - mut_results = commit_response_pb.mutation_results - index_updates = commit_response_pb.index_updates - completed_keys = [mut_result.key for mut_result in mut_results - if mut_result.HasField('key')] # Message field (Key) - return index_updates, completed_keys diff --git a/datastore/google/cloud/datastore/batch.py b/datastore/google/cloud/datastore/batch.py index 00854d2007b6..484af5a67c76 100644 --- a/datastore/google/cloud/datastore/batch.py +++ b/datastore/google/cloud/datastore/batch.py @@ -238,8 +238,9 @@ def _commit(self): This is called by :meth:`commit`. """ # NOTE: ``self._commit_request`` will be modified. - _, updated_keys = self._client._connection.commit( + commit_response_pb = self._client._connection.commit( self.project, self._commit_request, self._id) + _, updated_keys = _parse_commit_response(commit_response_pb) # If the back-end returns without error, we are guaranteed that # :meth:`Connection.commit` will return keys that match (length and # order) directly ``_partial_key_entities``. @@ -311,3 +312,21 @@ def _assign_entity_to_pb(entity_pb, entity): bare_entity_pb = helpers.entity_to_protobuf(entity) bare_entity_pb.key.CopyFrom(bare_entity_pb.key) entity_pb.CopyFrom(bare_entity_pb) + + +def _parse_commit_response(commit_response_pb): + """Extract response data from a commit response. + + :type commit_response_pb: :class:`.datastore_pb2.CommitResponse` + :param commit_response_pb: The protobuf response from a commit request. + + :rtype: tuple + :returns: The pair of the number of index updates and a list of + :class:`.entity_pb2.Key` for each incomplete key + that was completed in the commit. + """ + mut_results = commit_response_pb.mutation_results + index_updates = commit_response_pb.index_updates + completed_keys = [mut_result.key for mut_result in mut_results + if mut_result.HasField('key')] # Message field (Key) + return index_updates, completed_keys diff --git a/datastore/unit_tests/test__http.py b/datastore/unit_tests/test__http.py index 04087c7122c6..690cc45f325e 100644 --- a/datastore/unit_tests/test__http.py +++ b/datastore/unit_tests/test__http.py @@ -696,8 +696,8 @@ def test_commit_wo_transaction(self): from google.cloud.grpc.datastore.v1 import datastore_pb2 from google.cloud.datastore.helpers import _new_value_pb - PROJECT = 'PROJECT' - key_pb = self._make_key_pb(PROJECT) + project = 'PROJECT' + key_pb = self._make_key_pb(project) rsp_pb = datastore_pb2.CommitResponse() req_pb = datastore_pb2.CommitRequest() mutation = req_pb.mutations.add() @@ -708,44 +708,32 @@ def test_commit_wo_transaction(self): http = Http({'status': '200'}, rsp_pb.SerializeToString()) client = mock.Mock(_http=http, spec=['_http']) conn = self._make_one(client) - URI = '/'.join([ + uri = '/'.join([ conn.api_base_url, conn.API_VERSION, 'projects', - PROJECT + ':commit', + project + ':commit', ]) - # Set up mock for parsing the response. - expected_result = object() - _parsed = [] - - def mock_parse(response): - _parsed.append(response) - return expected_result - - patch = mock.patch( - 'google.cloud.datastore._http._parse_commit_response', - new=mock_parse) - with patch: - result = conn.commit(PROJECT, req_pb, None) + result = conn.commit(project, req_pb, None) + self.assertEqual(result, rsp_pb) - self.assertIs(result, expected_result) + # Verify the caller. cw = http._called_with - self._verify_protobuf_call(cw, URI, conn) + self._verify_protobuf_call(cw, uri, conn) rq_class = datastore_pb2.CommitRequest request = rq_class() request.ParseFromString(cw['body']) self.assertEqual(request.transaction, b'') self.assertEqual(list(request.mutations), [mutation]) self.assertEqual(request.mode, rq_class.NON_TRANSACTIONAL) - self.assertEqual(_parsed, [rsp_pb]) def test_commit_w_transaction(self): from google.cloud.grpc.datastore.v1 import datastore_pb2 from google.cloud.datastore.helpers import _new_value_pb - PROJECT = 'PROJECT' - key_pb = self._make_key_pb(PROJECT) + project = 'PROJECT' + key_pb = self._make_key_pb(project) rsp_pb = datastore_pb2.CommitResponse() req_pb = datastore_pb2.CommitRequest() mutation = req_pb.mutations.add() @@ -756,37 +744,25 @@ def test_commit_w_transaction(self): http = Http({'status': '200'}, rsp_pb.SerializeToString()) client = mock.Mock(_http=http, spec=['_http']) conn = self._make_one(client) - URI = '/'.join([ + uri = '/'.join([ conn.api_base_url, conn.API_VERSION, 'projects', - PROJECT + ':commit', + project + ':commit', ]) - # Set up mock for parsing the response. - expected_result = object() - _parsed = [] + result = conn.commit(project, req_pb, b'xact') + self.assertEqual(result, rsp_pb) - def mock_parse(response): - _parsed.append(response) - return expected_result - - patch = mock.patch( - 'google.cloud.datastore._http._parse_commit_response', - new=mock_parse) - with patch: - result = conn.commit(PROJECT, req_pb, b'xact') - - self.assertIs(result, expected_result) + # Verify the caller. cw = http._called_with - self._verify_protobuf_call(cw, URI, conn) + self._verify_protobuf_call(cw, uri, conn) rq_class = datastore_pb2.CommitRequest request = rq_class() request.ParseFromString(cw['body']) self.assertEqual(request.transaction, b'xact') self.assertEqual(list(request.mutations), [mutation]) self.assertEqual(request.mode, rq_class.TRANSACTIONAL) - self.assertEqual(_parsed, [rsp_pb]) def test_rollback_ok(self): from google.cloud.grpc.datastore.v1 import datastore_pb2 @@ -870,46 +846,6 @@ def test_allocate_ids_non_empty(self): self.assertEqual(key_before, key_after) -class Test__parse_commit_response(unittest.TestCase): - - def _call_fut(self, commit_response_pb): - from google.cloud.datastore._http import _parse_commit_response - - return _parse_commit_response(commit_response_pb) - - def test_it(self): - from google.cloud.grpc.datastore.v1 import datastore_pb2 - from google.cloud.grpc.datastore.v1 import entity_pb2 - - index_updates = 1337 - keys = [ - entity_pb2.Key( - path=[ - entity_pb2.Key.PathElement( - kind='Foo', - id=1234, - ), - ], - ), - entity_pb2.Key( - path=[ - entity_pb2.Key.PathElement( - kind='Bar', - name='baz', - ), - ], - ), - ] - response = datastore_pb2.CommitResponse( - mutation_results=[ - datastore_pb2.MutationResult(key=key) for key in keys - ], - index_updates=index_updates, - ) - result = self._call_fut(response) - self.assertEqual(result, (index_updates, keys)) - - class Http(object): _called_with = None diff --git a/datastore/unit_tests/test_batch.py b/datastore/unit_tests/test_batch.py index 72614d070cba..f2c54c680bef 100644 --- a/datastore/unit_tests/test_batch.py +++ b/datastore/unit_tests/test_batch.py @@ -249,13 +249,13 @@ def test_commit_wrong_status(self): self.assertRaises(ValueError, batch.commit) def test_commit_w_partial_key_entities(self): - _PROJECT = 'PROJECT' - _NEW_ID = 1234 - connection = _Connection(_NEW_ID) - client = _Client(_PROJECT, connection) + project = 'PROJECT' + new_id = 1234 + connection = _Connection(new_id) + client = _Client(project, connection) batch = self._make_one(client) entity = _Entity({}) - key = entity.key = _Key(_PROJECT) + key = entity.key = _Key(project) key._id = None batch._partial_key_entities.append(entity) @@ -266,9 +266,9 @@ def test_commit_w_partial_key_entities(self): self.assertEqual(batch._status, batch._FINISHED) self.assertEqual(connection._committed, - [(_PROJECT, batch._commit_request, None)]) + [(project, batch._commit_request, None)]) self.assertFalse(entity.key.is_partial) - self.assertEqual(entity.key._id, _NEW_ID) + self.assertEqual(entity.key._id, new_id) def test_as_context_mgr_wo_error(self): _PROJECT = 'PROJECT' @@ -369,30 +369,62 @@ def begin(self): self.assertEqual(client._batches, []) -class _PathElementPB(object): +class Test__parse_commit_response(unittest.TestCase): - def __init__(self, id_): - self.id = id_ + def _call_fut(self, commit_response_pb): + from google.cloud.datastore.batch import _parse_commit_response + return _parse_commit_response(commit_response_pb) -class _KeyPB(object): + def test_it(self): + from google.cloud.grpc.datastore.v1 import datastore_pb2 + from google.cloud.grpc.datastore.v1 import entity_pb2 - def __init__(self, id_): - self.path = [_PathElementPB(id_)] + index_updates = 1337 + keys = [ + entity_pb2.Key( + path=[ + entity_pb2.Key.PathElement( + kind='Foo', + id=1234, + ), + ], + ), + entity_pb2.Key( + path=[ + entity_pb2.Key.PathElement( + kind='Bar', + name='baz', + ), + ], + ), + ] + response = datastore_pb2.CommitResponse( + mutation_results=[ + datastore_pb2.MutationResult(key=key) for key in keys + ], + index_updates=index_updates, + ) + result = self._call_fut(response) + self.assertEqual(result, (index_updates, keys)) class _Connection(object): _marker = object() _save_result = (False, None) - def __init__(self, *new_keys): - self._completed_keys = [_KeyPB(key) for key in new_keys] + def __init__(self, *new_key_ids): + from google.cloud.grpc.datastore.v1 import datastore_pb2 + self._committed = [] - self._index_updates = 0 + mutation_results = [ + _make_mutation(key_id) for key_id in new_key_ids] + self._commit_response_pb = datastore_pb2.CommitResponse( + mutation_results=mutation_results) def commit(self, project, commit_request, transaction_id): self._committed.append((project, commit_request, transaction_id)) - return self._index_updates, self._completed_keys + return self._commit_response_pb class _Entity(dict): @@ -472,3 +504,15 @@ def _mutated_pb(test_case, mutation_pb_list, mutation_type): mutation_type) return getattr(mutated_pb, mutation_type) + + +def _make_mutation(id_): + from google.cloud.grpc.datastore.v1 import datastore_pb2 + from google.cloud.grpc.datastore.v1 import entity_pb2 + + key = entity_pb2.Key() + key.partition_id.project_id = 'PROJECT' + elem = key.path.add() + elem.kind = 'Kind' + elem.id = id_ + return datastore_pb2.MutationResult(key=key) diff --git a/datastore/unit_tests/test_client.py b/datastore/unit_tests/test_client.py index e89acc9c0922..0102abd7ca02 100644 --- a/datastore/unit_tests/test_client.py +++ b/datastore/unit_tests/test_client.py @@ -568,7 +568,8 @@ def test_put_multi_no_batch_w_partial_key(self): creds = _make_credentials() client = self._make_one(credentials=creds) - client._connection._commit.append([_KeyPB(key)]) + key_pb = _make_key(234) + client._connection._commit.append([key_pb]) result = client.put_multi([entity]) self.assertIsNone(result) @@ -931,7 +932,6 @@ def __init__(self, credentials=None, http=None): self._commit = [] self._alloc_cw = [] self._alloc = [] - self._index_updates = 0 def _add_lookup_result(self, results=(), missing=(), deferred=()): self._lookup.append((list(results), list(missing), list(deferred))) @@ -943,9 +943,13 @@ def lookup(self, project, key_pbs, eventual=False, transaction_id=None): return results, missing, deferred def commit(self, project, commit_request, transaction_id): + from google.cloud.grpc.datastore.v1 import datastore_pb2 + self._commit_cw.append((project, commit_request, transaction_id)) - response, self._commit = self._commit[0], self._commit[1:] - return self._index_updates, response + keys, self._commit = self._commit[0], self._commit[1:] + mutation_results = [ + datastore_pb2.MutationResult(key=key) for key in keys] + return datastore_pb2.CommitResponse(mutation_results=mutation_results) def allocate_ids(self, project, key_pbs): self._alloc_cw.append((project, key_pbs)) @@ -1058,3 +1062,12 @@ def _mutated_pb(test_case, mutation_pb_list, mutation_type): mutation_type) return getattr(mutated_pb, mutation_type) + + +def _make_key(id_): + from google.cloud.grpc.datastore.v1 import entity_pb2 + + key = entity_pb2.Key() + elem = key.path.add() + elem.id = id_ + return key diff --git a/datastore/unit_tests/test_transaction.py b/datastore/unit_tests/test_transaction.py index 7aa295bf7fca..4dec32420356 100644 --- a/datastore/unit_tests/test_transaction.py +++ b/datastore/unit_tests/test_transaction.py @@ -126,12 +126,12 @@ def test_commit_no_partial_keys(self): self.assertIsNone(xact.id) def test_commit_w_partial_keys(self): - _PROJECT = 'PROJECT' - _KIND = 'KIND' - _ID = 123 - connection = _Connection(234) - connection._completed_keys = [_make_key(_KIND, _ID, _PROJECT)] - client = _Client(_PROJECT, connection) + project = 'PROJECT' + kind = 'KIND' + id_ = 123 + key = _make_key(kind, id_, project) + connection = _Connection(234, keys=[key]) + client = _Client(project, connection) xact = self._make_one(client) xact.begin() entity = _Entity() @@ -139,9 +139,9 @@ def test_commit_w_partial_keys(self): xact._commit_request = commit_request = object() xact.commit() self.assertEqual(connection._committed, - (_PROJECT, commit_request, 234)) + (project, commit_request, 234)) self.assertIsNone(xact.id) - self.assertEqual(entity.key.path, [{'kind': _KIND, 'id': _ID}]) + self.assertEqual(entity.key.path, [{'kind': kind, 'id': id_}]) def test_context_manager_no_raise(self): _PROJECT = 'PROJECT' @@ -196,10 +196,14 @@ class _Connection(object): _committed = None _side_effect = None - def __init__(self, xact_id=123): + def __init__(self, xact_id=123, keys=()): + from google.cloud.grpc.datastore.v1 import datastore_pb2 + self._xact_id = xact_id - self._completed_keys = [] - self._index_updates = 0 + mutation_results = [ + datastore_pb2.MutationResult(key=key) for key in keys] + self._commit_response_pb = datastore_pb2.CommitResponse( + mutation_results=mutation_results) def begin_transaction(self, project): self._begun = project @@ -213,7 +217,7 @@ def rollback(self, project, transaction_id): def commit(self, project, commit_request, transaction_id): self._committed = (project, commit_request, transaction_id) - return self._index_updates, self._completed_keys + return self._commit_response_pb class _Entity(dict):