From af85bba32af06046ae92310cb6edc4206bf56d7d Mon Sep 17 00:00:00 2001 From: HemangChothani Date: Tue, 22 Sep 2020 18:38:42 +0530 Subject: [PATCH] test: raise coverage --- tests/unit/v1/test__helpers.py | 209 ++++++++++++++++++++------- tests/unit/v1/test_async_client.py | 54 ++++++- tests/unit/v1/test_async_document.py | 65 ++++++--- tests/unit/v1/test_async_query.py | 10 -- tests/unit/v1/test_base_client.py | 25 ++++ tests/unit/v1/test_client.py | 53 ++++++- tests/unit/v1/test_document.py | 64 +++++--- tests/unit/v1/test_order.py | 2 - tests/unit/v1/test_watch.py | 29 ++++ 9 files changed, 399 insertions(+), 112 deletions(-) diff --git a/tests/unit/v1/test__helpers.py b/tests/unit/v1/test__helpers.py index 55b74f89dc..f11bfa21c6 100644 --- a/tests/unit/v1/test__helpers.py +++ b/tests/unit/v1/test__helpers.py @@ -1526,9 +1526,10 @@ def _make_write_w_document(document_path, **data): ) @staticmethod - def _make_write_w_transform(document_path, fields): + def _make_write_w_transform(document_path, fields, set_field=True): from google.cloud.firestore_v1.types import write from google.cloud.firestore_v1 import DocumentTransform + from google.cloud.firestore_v1.types import common server_val = DocumentTransform.FieldTransform.ServerValue transforms = [ @@ -1537,18 +1538,27 @@ def _make_write_w_transform(document_path, fields): ) for field in fields ] - - return write.Write( - transform=write.DocumentTransform( - document=document_path, field_transforms=transforms + if not set_field: + return write.Write( + transform=write.DocumentTransform( + document=document_path, field_transforms=transforms + ), + current_document=common.Precondition(exists=False), + ) + else: + return write.Write( + transform=write.DocumentTransform( + document=document_path, field_transforms=transforms + ), ) - ) - def _helper(self, do_transform=False, empty_val=False): + def _helper(self, do_transform=False, empty_val=False, set_field=True): from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP document_path = _make_ref_string(u"little", u"town", u"of", u"ham") - document_data = {"cheese": 1.5, "crackers": True} + document_data = {} + if set_field: + document_data = {"cheese": 1.5, "crackers": True} if do_transform: document_data["butter"] = SERVER_TIMESTAMP @@ -1558,19 +1568,25 @@ def _helper(self, do_transform=False, empty_val=False): write_pbs = self._call_fut(document_path, document_data) + expected_pbs = [] if empty_val: - update_pb = self._make_write_w_document( - document_path, cheese=1.5, crackers=True, mustard={} + expected_pbs.append( + self._make_write_w_document( + document_path, cheese=1.5, crackers=True, mustard={} + ) ) + elif not set_field: + pass else: - update_pb = self._make_write_w_document( - document_path, cheese=1.5, crackers=True + expected_pbs.append( + self._make_write_w_document(document_path, cheese=1.5, crackers=True) ) - expected_pbs = [update_pb] if do_transform: expected_pbs.append( - self._make_write_w_transform(document_path, fields=["butter"]) + self._make_write_w_transform( + document_path, fields=["butter"], set_field=set_field + ) ) self.assertEqual(write_pbs, expected_pbs) @@ -1584,6 +1600,26 @@ def test_w_transform(self): def test_w_transform_and_empty_value(self): self._helper(do_transform=True, empty_val=True) + def test_w_transform_and_set_fields(self): + self._helper(do_transform=True, set_field=False) + + def test_w_deleted_fields(self): + from google.cloud.firestore_v1.transforms import DELETE_FIELD + + document_data = { + "write_me": "value", + "delete_me": DELETE_FIELD, + "ignore_me": 123, + } + document_path = _make_ref_string(u"little", u"town", u"of", u"ham") + + with self.assertRaises(ValueError) as exc: + self._call_fut(document_path, document_data) + + self.assertEqual( + exc.exception.args, ("Cannot apply DELETE_FIELD in a create request.",) + ) + class Test_pbs_for_set_no_merge(unittest.TestCase): @staticmethod @@ -1685,6 +1721,23 @@ def test_w_transform_and_empty_value(self): # Exercise #5944 self._helper(do_transform=True, empty_val=True) + def test_w_deleted_fields(self): + from google.cloud.firestore_v1.transforms import DELETE_FIELD + + document_data = { + "write_me": "value", + "delete_me": DELETE_FIELD, + "ignore_me": 123, + } + document_path = _make_ref_string(u"little", u"town", u"of", u"ham") + + with self.assertRaises(ValueError) as exc: + self._call_fut(document_path, document_data) + + self.assertIn( + "Cannot apply DELETE_FIELD in a set request", exc.exception.args[0] + ) + class TestDocumentExtractorForMerge(unittest.TestCase): @staticmethod @@ -1973,22 +2026,14 @@ def test_with_merge_true_w_transform(self): def test_with_merge_field_w_transform(self): from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + document_data = {} document_path = _make_ref_string(u"little", u"town", u"of", u"ham") - update_data = {"cheese": 1.5, "crackers": True} - document_data = update_data.copy() document_data["butter"] = SERVER_TIMESTAMP - write_pbs = self._call_fut( - document_path, document_data, merge=["cheese", "butter"] - ) + write_pbs = self._call_fut(document_path, document_data, merge=["butter"]) - update_pb = self._make_write_w_document( - document_path, cheese=document_data["cheese"] - ) - self._update_document_mask(update_pb, ["cheese"]) transform_pb = self._make_write_w_transform(document_path, fields=["butter"]) - expected_pbs = [update_pb, transform_pb] - self.assertEqual(write_pbs, expected_pbs) + self.assertEqual(write_pbs, [transform_pb]) def test_with_merge_field_w_transform_masking_simple(self): from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP @@ -2029,6 +2074,26 @@ def test_with_merge_field_w_transform_parent(self): expected_pbs = [update_pb, transform_pb] self.assertEqual(write_pbs, expected_pbs) + def test_wo_merge_field_w_transform(self): + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + + document_path = _make_ref_string(u"little", u"town", u"of", u"ham") + update_data = {"cheese": 1.5, "crackers": True} + document_data = update_data.copy() + document_data["butter"] = SERVER_TIMESTAMP + + write_pbs = self._call_fut( + document_path, document_data, merge=["cheese", "butter"] + ) + + update_pb = self._make_write_w_document( + document_path, cheese=document_data["cheese"] + ) + self._update_document_mask(update_pb, ["cheese"]) + transform_pb = self._make_write_w_transform(document_path, fields=["butter"]) + expected_pbs = [update_pb, transform_pb] + self.assertEqual(write_pbs, expected_pbs) + class TestDocumentExtractorForUpdate(unittest.TestCase): @staticmethod @@ -2092,6 +2157,24 @@ def test_ctor_w_nested_dotted_keys(self): self.assertEqual(inst.top_level_paths, expected_paths) self.assertEqual(inst.set_fields, expected_set_fields) + def test_ctor_w_dotted_keys_conflict(self): + document_data = {"a.d": {"h.i": 9}, "a.d.e": 1, "b.f": 7, "c": 3} + + with self.assertRaises(ValueError) as exc: + self._make_one(document_data) + + self.assertIn("Conflicting field path", exc.exception.args[0]) + + def test_ctor_w_dotted_keys_deleted_fields(self): + from google.cloud.firestore_v1.transforms import DELETE_FIELD + + document_data = {"a.d.e": {"h.i": DELETE_FIELD}, "b": 2, "c": 3} + + with self.assertRaises(ValueError) as exc: + self._make_one(document_data) + + self.assertIn("Cannot update with nest delete", exc.exception.args[0]) + class Test_pbs_for_update(unittest.TestCase): @staticmethod @@ -2100,7 +2183,7 @@ def _call_fut(document_path, field_updates, option): return pbs_for_update(document_path, field_updates, option) - def _helper(self, option=None, do_transform=False, **write_kwargs): + def _helper(self, option=None, do_transform=False, update=True, **write_kwargs): from google.cloud.firestore_v1 import _helpers from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP @@ -2110,45 +2193,64 @@ def _helper(self, option=None, do_transform=False, **write_kwargs): from google.cloud.firestore_v1.types import write document_path = _make_ref_string(u"toy", u"car", u"onion", u"garlic") - field_path1 = "bitez.yum" - value = b"\x00\x01" + field_updates = {} + if update: + field_path1 = "bitez.yum" + value = b"\x00\x01" + field_updates[field_path1] = value field_path2 = "blog.internet" - field_updates = {field_path1: value} if do_transform: field_updates[field_path2] = SERVER_TIMESTAMP write_pbs = self._call_fut(document_path, field_updates, option) - map_pb = document.MapValue(fields={"yum": _value_pb(bytes_value=value)}) + expected_pbs = [] + if update: + map_pb = document.MapValue(fields={"yum": _value_pb(bytes_value=value)}) - field_paths = [field_path1] + field_paths = [field_path1] + + expected_update_pb = write.Write( + update=document.Document( + name=document_path, fields={"bitez": _value_pb(map_value=map_pb)} + ), + update_mask=common.DocumentMask(field_paths=field_paths), + **write_kwargs + ) + if isinstance(option, _helpers.ExistsOption): + precondition = common.Precondition(exists=False) + expected_update_pb._pb.current_document.CopyFrom(precondition._pb) + expected_pbs.append(expected_update_pb) - expected_update_pb = write.Write( - update=document.Document( - name=document_path, fields={"bitez": _value_pb(map_value=map_pb)} - ), - update_mask=common.DocumentMask(field_paths=field_paths), - **write_kwargs - ) - if isinstance(option, _helpers.ExistsOption): - precondition = common.Precondition(exists=False) - expected_update_pb._pb.current_document.CopyFrom(precondition._pb) - expected_pbs = [expected_update_pb] if do_transform: transform_paths = FieldPath.from_string(field_path2) server_val = DocumentTransform.FieldTransform.ServerValue - expected_transform_pb = write.Write( - transform=write.DocumentTransform( - document=document_path, - field_transforms=[ - write.DocumentTransform.FieldTransform( - field_path=transform_paths.to_api_repr(), - set_to_server_value=server_val.REQUEST_TIME, - ) - ], + if update: + expected_transform_pb = write.Write( + transform=write.DocumentTransform( + document=document_path, + field_transforms=[ + write.DocumentTransform.FieldTransform( + field_path=transform_paths.to_api_repr(), + set_to_server_value=server_val.REQUEST_TIME, + ) + ], + ), + ) + else: + expected_transform_pb = write.Write( + transform=write.DocumentTransform( + document=document_path, + field_transforms=[ + write.DocumentTransform.FieldTransform( + field_path=transform_paths.to_api_repr(), + set_to_server_value=server_val.REQUEST_TIME, + ) + ], + ), + current_document=common.Precondition(exists=True), ) - ) expected_pbs.append(expected_transform_pb) self.assertEqual(write_pbs, expected_pbs) @@ -2170,6 +2272,9 @@ def test_update_and_transform(self): precondition = common.Precondition(exists=True) self._helper(current_document=precondition, do_transform=True) + def test_transform(self): + self._helper(do_transform=True, update=False) + class Test_pb_for_delete(unittest.TestCase): @staticmethod diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index 770d6ae204..0ea84d7e98 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -216,15 +216,15 @@ def __init__(self, pages): self.collection_ids = pages[0] def _next_page(self): - if self._pages: - page, self._pages = self._pages[0], self._pages[1:] - return Page(self, page, self.item_to_value) + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) iterator = _Iterator(pages=[collection_ids]) firestore_api.list_collection_ids.return_value = iterator - collections = [c async for c in client.collections()] + page = iterator._next_page() + self.assertEqual(page.num_items, len(collection_ids)) self.assertEqual(len(collections), len(collection_ids)) for collection, collection_id in zip(collections, collection_ids): self.assertIsInstance(collection, AsyncCollectionReference) @@ -236,6 +236,52 @@ def _next_page(self): request={"parent": base_path}, metadata=client._rpc_metadata ) + @pytest.mark.asyncio + async def test_collections_w_next_page_token(self): + from google.api_core.page_iterator import Iterator + from google.api_core.page_iterator import Page + + collection_ids = ["users", "projects"] + client = self._make_default_one() + firestore_api = AsyncMock() + firestore_api.mock_add_spec(spec=["list_collection_ids"]) + client._firestore_api_internal = firestore_api + + # TODO(microgen): list_collection_ids isn't a pager. + # https://github.com/googleapis/gapic-generator-python/issues/516 + class _Iterator(Iterator): + def __init__(self, pages): + super(_Iterator, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + self.next_page_token = "next_page_token" + + def _next_page(self): + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + class _Iterator1(Iterator): + def __init__(self, pages): + super(_Iterator1, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + + def _next_page(self): + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + iterator = _Iterator(pages=[collection_ids]) + iterator1 = _Iterator1(pages=[collection_ids]) + firestore_api.list_collection_ids.side_effect = [iterator, iterator1] + collections = [c async for c in client.collections()] + + page = iterator._next_page() + self.assertEqual(page.num_items, len(collection_ids)) + page1 = iterator1._next_page() + self.assertEqual(page1.num_items, len(collection_ids)) + self.assertEqual(len(collections), 4) + firestore_api.list_collection_ids.call_count = 2 + async def _get_all_helper(self, client, references, document_pbs, **kwargs): # Create a minimal fake GAPIC with a dummy response. firestore_api = AsyncMock(spec=["batch_get_documents"]) diff --git a/tests/unit/v1/test_async_document.py b/tests/unit/v1/test_async_document.py index 79a89d4abb..a58768069f 100644 --- a/tests/unit/v1/test_async_document.py +++ b/tests/unit/v1/test_async_document.py @@ -424,7 +424,7 @@ async def test_get_with_transaction(self): await self._get_helper(use_transaction=True) @pytest.mark.asyncio - async def _collections_helper(self, page_size=None): + async def _collections_helper(self, page_size=None, page_token=None): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.async_collection import AsyncCollectionReference @@ -437,37 +437,58 @@ def __init__(self, pages): self.collection_ids = pages[0] def _next_page(self): - if self._pages: - page, self._pages = self._pages[0], self._pages[1:] - return Page(self, page, self.item_to_value) + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + class _Iterator1(Iterator): + def __init__(self, pages): + super(_Iterator1, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + self.next_page_token = page_token + + def _next_page(self): + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) collection_ids = ["coll-1", "coll-2"] iterator = _Iterator(pages=[collection_ids]) firestore_api = AsyncMock() firestore_api.mock_add_spec(spec=["list_collection_ids"]) - firestore_api.list_collection_ids.return_value = iterator + page = iterator._next_page() + self.assertEqual(page.num_items, len(collection_ids)) + + if page_token: + iterator1 = _Iterator1(pages=[collection_ids]) + firestore_api.list_collection_ids.side_effect = [iterator1, iterator] + page1 = iterator1._next_page() + self.assertEqual(page1.num_items, len(collection_ids)) + else: + firestore_api.list_collection_ids.return_value = iterator client = _make_client() client._firestore_api_internal = firestore_api # Actually make a document and call delete(). document = self._make_one("where", "we-are", client=client) - if page_size is not None: - collections = [c async for c in document.collections(page_size=page_size)] - else: - collections = [c async for c in document.collections()] - # Verify the response and the mocks. - self.assertEqual(len(collections), len(collection_ids)) - for collection, collection_id in zip(collections, collection_ids): - self.assertIsInstance(collection, AsyncCollectionReference) - self.assertEqual(collection.parent, document) - self.assertEqual(collection.id, collection_id) - - firestore_api.list_collection_ids.assert_called_once_with( - request={"parent": document._document_path, "page_size": page_size}, - metadata=client._rpc_metadata, - ) + collections = [c async for c in document.collections(page_size=page_size)] + + if page_token: + self.assertEqual(len(collections), 4) + firestore_api.list_collection_ids.call_count = 2 + else: + # Verify the response and the mocks. + self.assertEqual(len(collections), len(collection_ids)) + for collection, collection_id in zip(collections, collection_ids): + self.assertIsInstance(collection, AsyncCollectionReference) + self.assertEqual(collection.parent, document) + self.assertEqual(collection.id, collection_id) + + firestore_api.list_collection_ids.assert_called_once_with( + request={"parent": document._document_path, "page_size": page_size}, + metadata=client._rpc_metadata, + ) @pytest.mark.asyncio async def test_collections_wo_page_size(self): @@ -477,6 +498,10 @@ async def test_collections_wo_page_size(self): async def test_collections_w_page_size(self): await self._collections_helper(page_size=10) + @pytest.mark.asyncio + async def test_collections_w_next_page(self): + await self._collections_helper(page_token="next_page_token") + def _make_credentials(): import google.auth.credentials diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index 14e41c2787..fb69d42f54 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -21,16 +21,6 @@ from tests.unit.v1.test_base_query import _make_credentials, _make_query_response -class MockAsyncIter: - def __init__(self, count=3): - # count is arbitrary value - self.count = count - - async def __aiter__(self, **_): - for i in range(self.count): - yield i - - class TestAsyncQuery(aiounittest.AsyncTestCase): @staticmethod def _get_target_class(): diff --git a/tests/unit/v1/test_base_client.py b/tests/unit/v1/test_base_client.py index 631733e075..21ddb55202 100644 --- a/tests/unit/v1/test_base_client.py +++ b/tests/unit/v1/test_base_client.py @@ -89,6 +89,31 @@ def test__firestore_api_property_with_emulator( self.assertIs(client._firestore_api, mock_client.return_value) self.assertEqual(mock_client.call_count, 1) + @mock.patch( + "google.cloud.firestore_v1.services.firestore.client.FirestoreClient", + autospec=True, + return_value=mock.sentinel.firestore_api, + ) + @mock.patch( + "google.cloud.firestore_v1.services.firestore.transports.grpc.FirestoreGrpcTransport.create_channel", + autospec=True, + ) + def test__firestore_api_property_with_default( + self, mock_insecure_channel, mock_client + ): + client = self._make_default_one() + self.assertIsNone(client._firestore_api_internal) + firestore_api = client._firestore_api + self.assertIs(firestore_api, mock_client.return_value) + self.assertIs(firestore_api, client._firestore_api_internal) + mock_client.assert_called_once_with( + transport=client._transport, client_options=None + ) + + # Call again to show that it is cached, but call count is still 1. + self.assertIs(client._firestore_api, mock_client.return_value) + self.assertEqual(mock_client.call_count, 1) + def test___database_string_property(self): credentials = _make_credentials() database = "cheeeeez" diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index b943fd1e14..61bd46fed8 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -212,15 +212,16 @@ def __init__(self, pages): self.collection_ids = pages[0] def _next_page(self): - if self._pages: - page, self._pages = self._pages[0], self._pages[1:] - return Page(self, page, self.item_to_value) + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) iterator = _Iterator(pages=[collection_ids]) firestore_api.list_collection_ids.return_value = iterator - collections = list(client.collections()) + collections = [c for c in client.collections()] + page = iterator._next_page() + self.assertEqual(page.num_items, len(collection_ids)) self.assertEqual(len(collections), len(collection_ids)) for collection, collection_id in zip(collections, collection_ids): self.assertIsInstance(collection, CollectionReference) @@ -232,6 +233,50 @@ def _next_page(self): request={"parent": base_path}, metadata=client._rpc_metadata ) + def test_collections_w_next_page_token(self): + from google.api_core.page_iterator import Iterator + from google.api_core.page_iterator import Page + + collection_ids = ["users", "projects"] + client = self._make_default_one() + firestore_api = mock.Mock(spec=["list_collection_ids"]) + client._firestore_api_internal = firestore_api + + # TODO(microgen): list_collection_ids isn't a pager. + # https://github.com/googleapis/gapic-generator-python/issues/516 + class _Iterator(Iterator): + def __init__(self, pages): + super(_Iterator, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + self.next_page_token = "next_page_token" + + def _next_page(self): + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + class _Iterator1(Iterator): + def __init__(self, pages): + super(_Iterator1, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + + def _next_page(self): + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + iterator = _Iterator(pages=[collection_ids]) + iterator1 = _Iterator1(pages=[collection_ids]) + firestore_api.list_collection_ids.side_effect = [iterator, iterator1] + collections = [c for c in client.collections()] + + page = iterator._next_page() + self.assertEqual(page.num_items, len(collection_ids)) + page1 = iterator1._next_page() + self.assertEqual(page1.num_items, len(collection_ids)) + self.assertEqual(len(collections), 4) + firestore_api.list_collection_ids.call_count = 2 + def _get_all_helper(self, client, references, document_pbs, **kwargs): # Create a minimal fake GAPIC with a dummy response. firestore_api = mock.Mock(spec=["batch_get_documents"]) diff --git a/tests/unit/v1/test_document.py b/tests/unit/v1/test_document.py index ff06532c4b..e1bd4d2f29 100644 --- a/tests/unit/v1/test_document.py +++ b/tests/unit/v1/test_document.py @@ -399,7 +399,7 @@ def test_get_with_multiple_field_paths(self): def test_get_with_transaction(self): self._get_helper(use_transaction=True) - def _collections_helper(self, page_size=None): + def _collections_helper(self, page_size=None, page_token=None): from google.api_core.page_iterator import Iterator from google.api_core.page_iterator import Page from google.cloud.firestore_v1.collection import CollectionReference @@ -413,36 +413,57 @@ def __init__(self, pages): self.collection_ids = pages[0] def _next_page(self): - if self._pages: - page, self._pages = self._pages[0], self._pages[1:] - return Page(self, page, self.item_to_value) + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + class _Iterator1(Iterator): + def __init__(self, pages): + super(_Iterator1, self).__init__(client=None) + self._pages = pages + self.collection_ids = pages[0] + self.next_page_token = page_token + + def _next_page(self): + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) collection_ids = ["coll-1", "coll-2"] iterator = _Iterator(pages=[collection_ids]) api_client = mock.create_autospec(FirestoreClient) - api_client.list_collection_ids.return_value = iterator + page = iterator._next_page() + self.assertEqual(page.num_items, len(collection_ids)) + + if page_token: + iterator1 = _Iterator1(pages=[collection_ids]) + api_client.list_collection_ids.side_effect = [iterator1, iterator] + page1 = iterator1._next_page() + self.assertEqual(page1.num_items, len(collection_ids)) + else: + api_client.list_collection_ids.return_value = iterator client = _make_client() client._firestore_api_internal = api_client # Actually make a document and call delete(). document = self._make_one("where", "we-are", client=client) - if page_size is not None: - collections = list(document.collections(page_size=page_size)) - else: - collections = list(document.collections()) - # Verify the response and the mocks. - self.assertEqual(len(collections), len(collection_ids)) - for collection, collection_id in zip(collections, collection_ids): - self.assertIsInstance(collection, CollectionReference) - self.assertEqual(collection.parent, document) - self.assertEqual(collection.id, collection_id) - - api_client.list_collection_ids.assert_called_once_with( - request={"parent": document._document_path, "page_size": page_size}, - metadata=client._rpc_metadata, - ) + collections = list(document.collections(page_size=page_size)) + + if page_token: + self.assertEqual(len(collections), 4) + api_client.list_collection_ids.call_count = 2 + else: + # Verify the response and the mocks. + self.assertEqual(len(collections), len(collection_ids)) + for collection, collection_id in zip(collections, collection_ids): + self.assertIsInstance(collection, CollectionReference) + self.assertEqual(collection.parent, document) + self.assertEqual(collection.id, collection_id) + + api_client.list_collection_ids.assert_called_once_with( + request={"parent": document._document_path, "page_size": page_size}, + metadata=client._rpc_metadata, + ) def test_collections_wo_page_size(self): self._collections_helper() @@ -450,6 +471,9 @@ def test_collections_wo_page_size(self): def test_collections_w_page_size(self): self._collections_helper(page_size=10) + def test_collections_w_next_page(self): + self._collections_helper(page_token="next_page_token") + @mock.patch("google.cloud.firestore_v1.document.Watch", autospec=True) def test_on_snapshot(self, watch): client = mock.Mock(_database_string="sprinklez", spec=["_database_string"]) diff --git a/tests/unit/v1/test_order.py b/tests/unit/v1/test_order.py index 4db743221c..90d99e563e 100644 --- a/tests/unit/v1/test_order.py +++ b/tests/unit/v1/test_order.py @@ -207,8 +207,6 @@ def _int_value(value): def _string_value(s): - if not isinstance(s, str): - s = str(s) return encode_value(s) diff --git a/tests/unit/v1/test_watch.py b/tests/unit/v1/test_watch.py index 759549b72a..d8a97a8af1 100644 --- a/tests/unit/v1/test_watch.py +++ b/tests/unit/v1/test_watch.py @@ -374,6 +374,18 @@ def test_on_snapshot_target_add(self): inst.on_snapshot(proto) self.assertEqual(str(exc.exception), "Unexpected target ID 1 sent by server") + def test_on_snapshot_target_add_successfully(self): + from google.cloud.firestore_v1.watch import WATCH_TARGET_ID + + inst = self._makeOne() + proto = DummyProto() + proto.target_change.target_change_type = ( + firestore.TargetChange.TargetChangeType.ADD + ) + proto.target_change.target_ids = [20601] # not "Py" + inst.on_snapshot(proto) + self.assertEqual(proto.target_change.target_ids[0], WATCH_TARGET_ID) + def test_on_snapshot_target_remove(self): inst = self._makeOne() proto = DummyProto() @@ -565,6 +577,23 @@ def test_on_snapshot_unknown_listen_type(self): str(exc.exception), ) + def test_on_snapshot_document_delete(self): + from google.cloud.firestore_v1.watch import ChangeType + + inst = self._makeOne() + proto = DummyProto() + proto.target_change = "" + proto.document_change = "" + + class DummyDelete(object): + document = "fred" + + delete = DummyDelete() + proto.document_remove = "" + proto.document_delete = delete + inst.on_snapshot(proto) + self.assertTrue(inst.change_map["fred"] is ChangeType.REMOVED) + def test_push_callback_called_no_changes(self): import pytz