diff --git a/google/cloud/firestore_v1/_helpers.py b/google/cloud/firestore_v1/_helpers.py index e98ec8547c..6dbabaf1d0 100644 --- a/google/cloud/firestore_v1/_helpers.py +++ b/google/cloud/firestore_v1/_helpers.py @@ -484,7 +484,6 @@ def get_update_pb( current_document = common.Precondition(exists=exists) else: current_document = None - update_pb = write.Write( update=document.Document( name=document_path, fields=encode_dict(self.set_fields) @@ -649,7 +648,6 @@ def has_updates(self): # of nested transform paths in the update mask # (see set-st-merge-nonleaf-alone.textproto) update_paths = set(self.data_merge) - for transform_path in self.transform_paths: if len(transform_path.parts) > 1: parent_fp = FieldPath(*transform_path.parts[:-1]) diff --git a/tests/unit/v1/test__helpers.py b/tests/unit/v1/test__helpers.py index c51084ac50..a3e8da15cf 100644 --- a/tests/unit/v1/test__helpers.py +++ b/tests/unit/v1/test__helpers.py @@ -1569,11 +1569,13 @@ def _add_field_transforms(update_pb, fields): ) ) - def _helper(self, do_transform=False, empty_val=False): + def _helper(self, do_transform=False, empty_val=False, set_field=True): from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP document_path = _make_ref_string(u"little", u"town", u"of", u"ham") - document_data = {"cheese": 1.5, "crackers": True} + document_data = {} + if set_field: + document_data = {"cheese": 1.5, "crackers": True} if do_transform: document_data["butter"] = SERVER_TIMESTAMP @@ -1592,7 +1594,6 @@ def _helper(self, do_transform=False, empty_val=False): document_path, cheese=1.5, crackers=True ) expected_pbs = [update_pb] - if do_transform: self._add_field_transforms(update_pb, fields=["butter"]) @@ -1607,6 +1608,23 @@ def test_w_transform(self): def test_w_transform_and_empty_value(self): self._helper(do_transform=True, empty_val=True) + def test_w_deleted_fields(self): + from google.cloud.firestore_v1.transforms import DELETE_FIELD + + document_data = { + "write_me": "value", + "delete_me": DELETE_FIELD, + "ignore_me": 123, + } + document_path = _make_ref_string(u"little", u"town", u"of", u"ham") + + with self.assertRaises(ValueError) as exc: + self._call_fut(document_path, document_data) + + self.assertEqual( + exc.exception.args, ("Cannot apply DELETE_FIELD in a create request.",) + ) + class Test_pbs_for_set_no_merge(unittest.TestCase): @staticmethod @@ -1699,6 +1717,23 @@ def test_w_transform_and_empty_value(self): # Exercise #5944 self._helper(do_transform=True, empty_val=True) + def test_w_deleted_fields(self): + from google.cloud.firestore_v1.transforms import DELETE_FIELD + + document_data = { + "write_me": "value", + "delete_me": DELETE_FIELD, + "ignore_me": 123, + } + document_path = _make_ref_string(u"little", u"town", u"of", u"ham") + + with self.assertRaises(ValueError) as exc: + self._call_fut(document_path, document_data) + + self.assertIn( + "Cannot apply DELETE_FIELD in a set request", exc.exception.args[0] + ) + class TestDocumentExtractorForMerge(unittest.TestCase): @staticmethod @@ -1897,6 +1932,36 @@ def test_apply_merge_list_fields_w_array_union(self): self.assertEqual(inst.array_unions, expected_array_unions) self.assertTrue(inst.has_updates) + def test_apply_merge_multiple_transform_paths(self): + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP, ArrayUnion + + values = [1, 3, 5] + document_data = { + "write_me": "value", + "b": {"timestamp": SERVER_TIMESTAMP}, + "union_me": ArrayUnion(values), + } + inst = self._make_one(document_data) + + inst.apply_merge(True) + + expected_data_merge = [_make_field_path("write_me")] + expected_transform_merge = [ + _make_field_path("b", "timestamp"), + _make_field_path("union_me"), + ] + expected_merge = [ + _make_field_path("b", "timestamp"), + _make_field_path("union_me"), + _make_field_path("write_me"), + ] + self.assertEqual(inst.data_merge, expected_data_merge) + self.assertEqual(inst.transform_merge, expected_transform_merge) + self.assertEqual(inst.merge, expected_merge) + expected_array_unions = {_make_field_path("union_me"): values} + self.assertEqual(inst.array_unions, expected_array_unions) + self.assertTrue(inst.has_updates) + class Test_pbs_for_set_with_merge(unittest.TestCase): @staticmethod @@ -2110,6 +2175,24 @@ def test_ctor_w_nested_dotted_keys(self): self.assertEqual(inst.top_level_paths, expected_paths) self.assertEqual(inst.set_fields, expected_set_fields) + def test_ctor_w_dotted_keys_conflict(self): + document_data = {"a.d": {"h.i": 9}, "a.d.e": 1, "b.f": 7, "c": 3} + + with self.assertRaises(ValueError) as exc: + self._make_one(document_data) + + self.assertIn("Conflicting field path", exc.exception.args[0]) + + def test_ctor_w_dotted_keys_deleted_fields(self): + from google.cloud.firestore_v1.transforms import DELETE_FIELD + + document_data = {"a.d.e": {"h.i": DELETE_FIELD}, "b": 2, "c": 3} + + with self.assertRaises(ValueError) as exc: + self._make_one(document_data) + + self.assertIn("Cannot update with nest delete", exc.exception.args[0]) + class Test_pbs_for_update(unittest.TestCase): @staticmethod diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index 770d6ae204..0ea84d7e98 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -216,15 +216,15 @@ def __init__(self, pages): self.collection_ids = pages[0] def _next_page(self): - if self._pages: - page, self._pages = self._pages[0], self._pages[1:] - return Page(self, page, self.item_to_value) + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) iterator = _Iterator(pages=[collection_ids]) firestore_api.list_collection_ids.return_value = iterator - collections = [c async for c in client.collections()] + page = iterator._next_page() + self.assertEqual(page.num_items, len(collection_ids)) self.assertEqual(len(collections), len(collection_ids)) for collection, collection_id in zip(collections, collection_ids): self.assertIsInstance(collection, AsyncCollectionReference) @@ -236,6 +236,52 @@ def _next_page(self): request={"parent": base_path}, metadata=client._rpc_metadata ) + @pytest.mark.asyncio + async def test_collections_w_next_page_token(self): + from google.api_core.page_iterator import Iterator + from google.api_core.page_iterator import Page + + collection_ids = ["users", "projects"] + client = self._make_default_one() + firestore_api = AsyncMock() + firestore_api.mock_add_spec(spec=["list_collection_ids"]) + client._firestore_api_internal = firestore_api + + # TODO(microgen): list_collection_ids isn't a pager. + # https://github.com/googleapis/gapic-generator-python/issues/516 + class _Iterator(Iterator): + def __init__(self, pages): + super(_Iterator, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + self.next_page_token = "next_page_token" + + def _next_page(self): + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + class _Iterator1(Iterator): + def __init__(self, pages): + super(_Iterator1, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + + def _next_page(self): + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + iterator = _Iterator(pages=[collection_ids]) + iterator1 = _Iterator1(pages=[collection_ids]) + firestore_api.list_collection_ids.side_effect = [iterator, iterator1] + collections = [c async for c in client.collections()] + + page = iterator._next_page() + self.assertEqual(page.num_items, len(collection_ids)) + page1 = iterator1._next_page() + self.assertEqual(page1.num_items, len(collection_ids)) + self.assertEqual(len(collections), 4) + firestore_api.list_collection_ids.call_count = 2 + async def _get_all_helper(self, client, references, document_pbs, **kwargs): # Create a minimal fake GAPIC with a dummy response. firestore_api = AsyncMock(spec=["batch_get_documents"]) diff --git a/tests/unit/v1/test_async_document.py b/tests/unit/v1/test_async_document.py index 79a89d4abb..a58768069f 100644 --- a/tests/unit/v1/test_async_document.py +++ b/tests/unit/v1/test_async_document.py @@ -424,7 +424,7 @@ async def test_get_with_transaction(self): await self._get_helper(use_transaction=True) @pytest.mark.asyncio - async def _collections_helper(self, page_size=None): + async def _collections_helper(self, page_size=None, page_token=None): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.async_collection import AsyncCollectionReference @@ -437,37 +437,58 @@ def __init__(self, pages): self.collection_ids = pages[0] def _next_page(self): - if self._pages: - page, self._pages = self._pages[0], self._pages[1:] - return Page(self, page, self.item_to_value) + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + class _Iterator1(Iterator): + def __init__(self, pages): + super(_Iterator1, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + self.next_page_token = page_token + + def _next_page(self): + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) collection_ids = ["coll-1", "coll-2"] iterator = _Iterator(pages=[collection_ids]) firestore_api = AsyncMock() firestore_api.mock_add_spec(spec=["list_collection_ids"]) - firestore_api.list_collection_ids.return_value = iterator + page = iterator._next_page() + self.assertEqual(page.num_items, len(collection_ids)) + + if page_token: + iterator1 = _Iterator1(pages=[collection_ids]) + firestore_api.list_collection_ids.side_effect = [iterator1, iterator] + page1 = iterator1._next_page() + self.assertEqual(page1.num_items, len(collection_ids)) + else: + firestore_api.list_collection_ids.return_value = iterator client = _make_client() client._firestore_api_internal = firestore_api # Actually make a document and call delete(). document = self._make_one("where", "we-are", client=client) - if page_size is not None: - collections = [c async for c in document.collections(page_size=page_size)] - else: - collections = [c async for c in document.collections()] - # Verify the response and the mocks. - self.assertEqual(len(collections), len(collection_ids)) - for collection, collection_id in zip(collections, collection_ids): - self.assertIsInstance(collection, AsyncCollectionReference) - self.assertEqual(collection.parent, document) - self.assertEqual(collection.id, collection_id) - - firestore_api.list_collection_ids.assert_called_once_with( - request={"parent": document._document_path, "page_size": page_size}, - metadata=client._rpc_metadata, - ) + collections = [c async for c in document.collections(page_size=page_size)] + + if page_token: + self.assertEqual(len(collections), 4) + firestore_api.list_collection_ids.call_count = 2 + else: + # Verify the response and the mocks. + self.assertEqual(len(collections), len(collection_ids)) + for collection, collection_id in zip(collections, collection_ids): + self.assertIsInstance(collection, AsyncCollectionReference) + self.assertEqual(collection.parent, document) + self.assertEqual(collection.id, collection_id) + + firestore_api.list_collection_ids.assert_called_once_with( + request={"parent": document._document_path, "page_size": page_size}, + metadata=client._rpc_metadata, + ) @pytest.mark.asyncio async def test_collections_wo_page_size(self): @@ -477,6 +498,10 @@ async def test_collections_wo_page_size(self): async def test_collections_w_page_size(self): await self._collections_helper(page_size=10) + @pytest.mark.asyncio + async def test_collections_w_next_page(self): + await self._collections_helper(page_token="next_page_token") + def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index 944c63ae02..fcc60c5e44 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -25,16 +25,6 @@ ) -class MockAsyncIter: - def __init__(self, count=3): - # count is arbitrary value - self.count = count - - async def __aiter__(self, **_): - for i in range(self.count): - yield i - - class TestAsyncQuery(aiounittest.AsyncTestCase): @staticmethod def _get_target_class(): diff --git a/tests/unit/v1/test_base_client.py b/tests/unit/v1/test_base_client.py index 631733e075..21ddb55202 100644 --- a/tests/unit/v1/test_base_client.py +++ b/tests/unit/v1/test_base_client.py @@ -89,6 +89,31 @@ def test__firestore_api_property_with_emulator( self.assertIs(client._firestore_api, mock_client.return_value) self.assertEqual(mock_client.call_count, 1) + @mock.patch( + "google.cloud.firestore_v1.services.firestore.client.FirestoreClient", + autospec=True, + return_value=mock.sentinel.firestore_api, + ) + @mock.patch( + "google.cloud.firestore_v1.services.firestore.transports.grpc.FirestoreGrpcTransport.create_channel", + autospec=True, + ) + def test__firestore_api_property_with_default( + self, mock_insecure_channel, mock_client + ): + client = self._make_default_one() + self.assertIsNone(client._firestore_api_internal) + firestore_api = client._firestore_api + self.assertIs(firestore_api, mock_client.return_value) + self.assertIs(firestore_api, client._firestore_api_internal) + mock_client.assert_called_once_with( + transport=client._transport, client_options=None + ) + + # Call again to show that it is cached, but call count is still 1. + self.assertIs(client._firestore_api, mock_client.return_value) + self.assertEqual(mock_client.call_count, 1) + def test___database_string_property(self): credentials = _make_credentials() database = "cheeeeez" diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index b943fd1e14..61bd46fed8 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -212,15 +212,16 @@ def __init__(self, pages): self.collection_ids = pages[0] def _next_page(self): - if self._pages: - page, self._pages = self._pages[0], self._pages[1:] - return Page(self, page, self.item_to_value) + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) iterator = _Iterator(pages=[collection_ids]) firestore_api.list_collection_ids.return_value = iterator - collections = list(client.collections()) + collections = [c for c in client.collections()] + page = iterator._next_page() + self.assertEqual(page.num_items, len(collection_ids)) self.assertEqual(len(collections), len(collection_ids)) for collection, collection_id in zip(collections, collection_ids): self.assertIsInstance(collection, CollectionReference) @@ -232,6 +233,50 @@ def _next_page(self): request={"parent": base_path}, metadata=client._rpc_metadata ) + def test_collections_w_next_page_token(self): + from google.api_core.page_iterator import Iterator + from google.api_core.page_iterator import Page + + collection_ids = ["users", "projects"] + client = self._make_default_one() + firestore_api = mock.Mock(spec=["list_collection_ids"]) + client._firestore_api_internal = firestore_api + + # TODO(microgen): list_collection_ids isn't a pager. + # https://github.com/googleapis/gapic-generator-python/issues/516 + class _Iterator(Iterator): + def __init__(self, pages): + super(_Iterator, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + self.next_page_token = "next_page_token" + + def _next_page(self): + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + class _Iterator1(Iterator): + def __init__(self, pages): + super(_Iterator1, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + + def _next_page(self): + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + iterator = _Iterator(pages=[collection_ids]) + iterator1 = _Iterator1(pages=[collection_ids]) + firestore_api.list_collection_ids.side_effect = [iterator, iterator1] + collections = [c for c in client.collections()] + + page = iterator._next_page() + self.assertEqual(page.num_items, len(collection_ids)) + page1 = iterator1._next_page() + self.assertEqual(page1.num_items, len(collection_ids)) + self.assertEqual(len(collections), 4) + firestore_api.list_collection_ids.call_count = 2 + def _get_all_helper(self, client, references, document_pbs, **kwargs): # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["batch_get_documents"]) diff --git a/tests/unit/v1/test_document.py b/tests/unit/v1/test_document.py index ff06532c4b..e1bd4d2f29 100644 --- a/tests/unit/v1/test_document.py +++ b/tests/unit/v1/test_document.py @@ -399,7 +399,7 @@ def test_get_with_multiple_field_paths(self): def test_get_with_transaction(self): self._get_helper(use_transaction=True) - def _collections_helper(self, page_size=None): + def _collections_helper(self, page_size=None, page_token=None): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.collection import CollectionReference @@ -413,36 +413,57 @@ def __init__(self, pages): self.collection_ids = pages[0] def _next_page(self): - if self._pages: - page, self._pages = self._pages[0], self._pages[1:] - return Page(self, page, self.item_to_value) + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + class _Iterator1(Iterator): + def __init__(self, pages): + super(_Iterator1, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + self.next_page_token = page_token + + def _next_page(self): + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) collection_ids = ["coll-1", "coll-2"] iterator = _Iterator(pages=[collection_ids]) api_client = mock.create_autospec(FirestoreClient) - api_client.list_collection_ids.return_value = iterator + page = iterator._next_page() + self.assertEqual(page.num_items, len(collection_ids)) + + if page_token: + iterator1 = _Iterator1(pages=[collection_ids]) + api_client.list_collection_ids.side_effect = [iterator1, iterator] + page1 = iterator1._next_page() + self.assertEqual(page1.num_items, len(collection_ids)) + else: + api_client.list_collection_ids.return_value = iterator client = _make_client() client._firestore_api_internal = api_client # Actually make a document and call delete(). document = self._make_one("where", "we-are", client=client) - if page_size is not None: - collections = list(document.collections(page_size=page_size)) - else: - collections = list(document.collections()) - # Verify the response and the mocks. - self.assertEqual(len(collections), len(collection_ids)) - for collection, collection_id in zip(collections, collection_ids): - self.assertIsInstance(collection, CollectionReference) - self.assertEqual(collection.parent, document) - self.assertEqual(collection.id, collection_id) - - api_client.list_collection_ids.assert_called_once_with( - request={"parent": document._document_path, "page_size": page_size}, - metadata=client._rpc_metadata, - ) + collections = list(document.collections(page_size=page_size)) + + if page_token: + self.assertEqual(len(collections), 4) + api_client.list_collection_ids.call_count = 2 + else: + # Verify the response and the mocks. + self.assertEqual(len(collections), len(collection_ids)) + for collection, collection_id in zip(collections, collection_ids): + self.assertIsInstance(collection, CollectionReference) + self.assertEqual(collection.parent, document) + self.assertEqual(collection.id, collection_id) + + api_client.list_collection_ids.assert_called_once_with( + request={"parent": document._document_path, "page_size": page_size}, + metadata=client._rpc_metadata, + ) def test_collections_wo_page_size(self): self._collections_helper() @@ -450,6 +471,9 @@ def test_collections_wo_page_size(self): def test_collections_w_page_size(self): self._collections_helper(page_size=10) + def test_collections_w_next_page(self): + self._collections_helper(page_token="next_page_token") + @mock.patch("google.cloud.firestore_v1.document.Watch", autospec=True) def test_on_snapshot(self, watch): client = mock.Mock(_database_string="sprinklez", spec=["_database_string"]) diff --git a/tests/unit/v1/test_order.py b/tests/unit/v1/test_order.py index 4db743221c..90d99e563e 100644 --- a/tests/unit/v1/test_order.py +++ b/tests/unit/v1/test_order.py @@ -207,8 +207,6 @@ def _int_value(value): def _string_value(s): - if not isinstance(s, str): - s = str(s) return encode_value(s) diff --git a/tests/unit/v1/test_watch.py b/tests/unit/v1/test_watch.py index 759549b72a..d8a97a8af1 100644 --- a/tests/unit/v1/test_watch.py +++ b/tests/unit/v1/test_watch.py @@ -374,6 +374,18 @@ def test_on_snapshot_target_add(self): inst.on_snapshot(proto) self.assertEqual(str(exc.exception), "Unexpected target ID 1 sent by server") + def test_on_snapshot_target_add_successfully(self): + from google.cloud.firestore_v1.watch import WATCH_TARGET_ID + + inst = self._makeOne() + proto = DummyProto() + proto.target_change.target_change_type = ( + firestore.TargetChange.TargetChangeType.ADD + ) + proto.target_change.target_ids = [20601] # not "Py" + inst.on_snapshot(proto) + self.assertEqual(proto.target_change.target_ids[0], WATCH_TARGET_ID) + def test_on_snapshot_target_remove(self): inst = self._makeOne() proto = DummyProto() @@ -565,6 +577,23 @@ def test_on_snapshot_unknown_listen_type(self): str(exc.exception), ) + def test_on_snapshot_document_delete(self): + from google.cloud.firestore_v1.watch import ChangeType + + inst = self._makeOne() + proto = DummyProto() + proto.target_change = "" + proto.document_change = "" + + class DummyDelete(object): + document = "fred" + + delete = DummyDelete() + proto.document_remove = "" + proto.document_delete = delete + inst.on_snapshot(proto) + self.assertTrue(inst.change_map["fred"] is ChangeType.REMOVED) + def test_push_callback_called_no_changes(self): import pytz