Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions google/cloud/firestore_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,6 @@ def get_update_pb(
current_document = common.Precondition(exists=exists)
else:
current_document = None

update_pb = write.Write(
update=document.Document(
name=document_path, fields=encode_dict(self.set_fields)
Expand Down Expand Up @@ -649,7 +648,6 @@ def has_updates(self):
# of nested transform paths in the update mask
# (see set-st-merge-nonleaf-alone.textproto)
update_paths = set(self.data_merge)

for transform_path in self.transform_paths:
if len(transform_path.parts) > 1:
parent_fp = FieldPath(*transform_path.parts[:-1])
Expand Down
89 changes: 86 additions & 3 deletions tests/unit/v1/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,11 +1569,13 @@ def _add_field_transforms(update_pb, fields):
)
)

def _helper(self, do_transform=False, empty_val=False):
def _helper(self, do_transform=False, empty_val=False, set_field=True):
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP

document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
document_data = {"cheese": 1.5, "crackers": True}
document_data = {}
if set_field:
document_data = {"cheese": 1.5, "crackers": True}

if do_transform:
document_data["butter"] = SERVER_TIMESTAMP
Expand All @@ -1592,7 +1594,6 @@ def _helper(self, do_transform=False, empty_val=False):
document_path, cheese=1.5, crackers=True
)
expected_pbs = [update_pb]

if do_transform:
self._add_field_transforms(update_pb, fields=["butter"])

Expand All @@ -1607,6 +1608,23 @@ def test_w_transform(self):
def test_w_transform_and_empty_value(self):
self._helper(do_transform=True, empty_val=True)

def test_w_deleted_fields(self):
from google.cloud.firestore_v1.transforms import DELETE_FIELD

document_data = {
"write_me": "value",
"delete_me": DELETE_FIELD,
"ignore_me": 123,
}
document_path = _make_ref_string(u"little", u"town", u"of", u"ham")

with self.assertRaises(ValueError) as exc:
self._call_fut(document_path, document_data)

self.assertEqual(
exc.exception.args, ("Cannot apply DELETE_FIELD in a create request.",)
)


class Test_pbs_for_set_no_merge(unittest.TestCase):
@staticmethod
Expand Down Expand Up @@ -1699,6 +1717,23 @@ def test_w_transform_and_empty_value(self):
# Exercise #5944
self._helper(do_transform=True, empty_val=True)

def test_w_deleted_fields(self):
from google.cloud.firestore_v1.transforms import DELETE_FIELD

document_data = {
"write_me": "value",
"delete_me": DELETE_FIELD,
"ignore_me": 123,
}
document_path = _make_ref_string(u"little", u"town", u"of", u"ham")

with self.assertRaises(ValueError) as exc:
self._call_fut(document_path, document_data)

self.assertIn(
"Cannot apply DELETE_FIELD in a set request", exc.exception.args[0]
)


class TestDocumentExtractorForMerge(unittest.TestCase):
@staticmethod
Expand Down Expand Up @@ -1897,6 +1932,36 @@ def test_apply_merge_list_fields_w_array_union(self):
self.assertEqual(inst.array_unions, expected_array_unions)
self.assertTrue(inst.has_updates)

def test_apply_merge_multiple_transform_paths(self):
from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP, ArrayUnion

values = [1, 3, 5]
document_data = {
"write_me": "value",
"b": {"timestamp": SERVER_TIMESTAMP},
"union_me": ArrayUnion(values),
}
inst = self._make_one(document_data)

inst.apply_merge(True)

expected_data_merge = [_make_field_path("write_me")]
expected_transform_merge = [
_make_field_path("b", "timestamp"),
_make_field_path("union_me"),
]
expected_merge = [
_make_field_path("b", "timestamp"),
_make_field_path("union_me"),
_make_field_path("write_me"),
]
self.assertEqual(inst.data_merge, expected_data_merge)
self.assertEqual(inst.transform_merge, expected_transform_merge)
self.assertEqual(inst.merge, expected_merge)
expected_array_unions = {_make_field_path("union_me"): values}
self.assertEqual(inst.array_unions, expected_array_unions)
self.assertTrue(inst.has_updates)


class Test_pbs_for_set_with_merge(unittest.TestCase):
@staticmethod
Expand Down Expand Up @@ -2110,6 +2175,24 @@ def test_ctor_w_nested_dotted_keys(self):
self.assertEqual(inst.top_level_paths, expected_paths)
self.assertEqual(inst.set_fields, expected_set_fields)

def test_ctor_w_dotted_keys_conflict(self):
document_data = {"a.d": {"h.i": 9}, "a.d.e": 1, "b.f": 7, "c": 3}

with self.assertRaises(ValueError) as exc:
self._make_one(document_data)

self.assertIn("Conflicting field path", exc.exception.args[0])

def test_ctor_w_dotted_keys_deleted_fields(self):
from google.cloud.firestore_v1.transforms import DELETE_FIELD

document_data = {"a.d.e": {"h.i": DELETE_FIELD}, "b": 2, "c": 3}

with self.assertRaises(ValueError) as exc:
self._make_one(document_data)

self.assertIn("Cannot update with nest delete", exc.exception.args[0])


class Test_pbs_for_update(unittest.TestCase):
@staticmethod
Expand Down
54 changes: 50 additions & 4 deletions tests/unit/v1/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,15 @@ def __init__(self, pages):
self.collection_ids = pages[0]

def _next_page(self):
if self._pages:
page, self._pages = self._pages[0], self._pages[1:]
return Page(self, page, self.item_to_value)
page, self._pages = self._pages[0], self._pages[1:]
return Page(self, page, self.item_to_value)

iterator = _Iterator(pages=[collection_ids])
firestore_api.list_collection_ids.return_value = iterator

collections = [c async for c in client.collections()]

page = iterator._next_page()
self.assertEqual(page.num_items, len(collection_ids))
self.assertEqual(len(collections), len(collection_ids))
for collection, collection_id in zip(collections, collection_ids):
self.assertIsInstance(collection, AsyncCollectionReference)
Expand All @@ -236,6 +236,52 @@ def _next_page(self):
request={"parent": base_path}, metadata=client._rpc_metadata
)

@pytest.mark.asyncio
async def test_collections_w_next_page_token(self):
from google.api_core.page_iterator import Iterator
from google.api_core.page_iterator import Page

collection_ids = ["users", "projects"]
client = self._make_default_one()
firestore_api = AsyncMock()
firestore_api.mock_add_spec(spec=["list_collection_ids"])
client._firestore_api_internal = firestore_api

# TODO(microgen): list_collection_ids isn't a pager.
# https://github.com/googleapis/gapic-generator-python/issues/516
class _Iterator(Iterator):
def __init__(self, pages):
super(_Iterator, self).__init__(client=None)
self._pages = pages
self.collection_ids = pages[0]
self.next_page_token = "next_page_token"

def _next_page(self):
page, self._pages = self._pages[0], self._pages[1:]
return Page(self, page, self.item_to_value)

class _Iterator1(Iterator):
def __init__(self, pages):
super(_Iterator1, self).__init__(client=None)
self._pages = pages
self.collection_ids = pages[0]

def _next_page(self):
page, self._pages = self._pages[0], self._pages[1:]
return Page(self, page, self.item_to_value)

iterator = _Iterator(pages=[collection_ids])
iterator1 = _Iterator1(pages=[collection_ids])
firestore_api.list_collection_ids.side_effect = [iterator, iterator1]
collections = [c async for c in client.collections()]

page = iterator._next_page()
self.assertEqual(page.num_items, len(collection_ids))
page1 = iterator1._next_page()
self.assertEqual(page1.num_items, len(collection_ids))
self.assertEqual(len(collections), 4)
firestore_api.list_collection_ids.call_count = 2

async def _get_all_helper(self, client, references, document_pbs, **kwargs):
# Create a minimal fake GAPIC with a dummy response.
firestore_api = AsyncMock(spec=["batch_get_documents"])
Expand Down
65 changes: 45 additions & 20 deletions tests/unit/v1/test_async_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ async def test_get_with_transaction(self):
await self._get_helper(use_transaction=True)

@pytest.mark.asyncio
async def _collections_helper(self, page_size=None):
async def _collections_helper(self, page_size=None, page_token=None):
from google.api_core.page_iterator import Iterator
from google.api_core.page_iterator import Page
from google.cloud.firestore_v1.async_collection import AsyncCollectionReference
Expand All @@ -437,37 +437,58 @@ def __init__(self, pages):
self.collection_ids = pages[0]

def _next_page(self):
if self._pages:
page, self._pages = self._pages[0], self._pages[1:]
return Page(self, page, self.item_to_value)
page, self._pages = self._pages[0], self._pages[1:]
return Page(self, page, self.item_to_value)

class _Iterator1(Iterator):
def __init__(self, pages):
super(_Iterator1, self).__init__(client=None)
self._pages = pages
self.collection_ids = pages[0]
self.next_page_token = page_token

def _next_page(self):
page, self._pages = self._pages[0], self._pages[1:]
return Page(self, page, self.item_to_value)

collection_ids = ["coll-1", "coll-2"]
iterator = _Iterator(pages=[collection_ids])
firestore_api = AsyncMock()
firestore_api.mock_add_spec(spec=["list_collection_ids"])
firestore_api.list_collection_ids.return_value = iterator
page = iterator._next_page()
self.assertEqual(page.num_items, len(collection_ids))

if page_token:
iterator1 = _Iterator1(pages=[collection_ids])
firestore_api.list_collection_ids.side_effect = [iterator1, iterator]
page1 = iterator1._next_page()
self.assertEqual(page1.num_items, len(collection_ids))
else:
firestore_api.list_collection_ids.return_value = iterator

client = _make_client()
client._firestore_api_internal = firestore_api

# Actually make a document and call delete().
document = self._make_one("where", "we-are", client=client)
if page_size is not None:
collections = [c async for c in document.collections(page_size=page_size)]
else:
collections = [c async for c in document.collections()]

# Verify the response and the mocks.
self.assertEqual(len(collections), len(collection_ids))
for collection, collection_id in zip(collections, collection_ids):
self.assertIsInstance(collection, AsyncCollectionReference)
self.assertEqual(collection.parent, document)
self.assertEqual(collection.id, collection_id)

firestore_api.list_collection_ids.assert_called_once_with(
request={"parent": document._document_path, "page_size": page_size},
metadata=client._rpc_metadata,
)
collections = [c async for c in document.collections(page_size=page_size)]

if page_token:
self.assertEqual(len(collections), 4)
firestore_api.list_collection_ids.call_count = 2
else:
# Verify the response and the mocks.
self.assertEqual(len(collections), len(collection_ids))
for collection, collection_id in zip(collections, collection_ids):
self.assertIsInstance(collection, AsyncCollectionReference)
self.assertEqual(collection.parent, document)
self.assertEqual(collection.id, collection_id)

firestore_api.list_collection_ids.assert_called_once_with(
request={"parent": document._document_path, "page_size": page_size},
metadata=client._rpc_metadata,
)

@pytest.mark.asyncio
async def test_collections_wo_page_size(self):
Expand All @@ -477,6 +498,10 @@ async def test_collections_wo_page_size(self):
async def test_collections_w_page_size(self):
await self._collections_helper(page_size=10)

@pytest.mark.asyncio
async def test_collections_w_next_page(self):
await self._collections_helper(page_token="next_page_token")


def _make_credentials():
import google.auth.credentials
Expand Down
10 changes: 0 additions & 10 deletions tests/unit/v1/test_async_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,6 @@
)


class MockAsyncIter:
def __init__(self, count=3):
# count is arbitrary value
self.count = count

async def __aiter__(self, **_):
for i in range(self.count):
yield i


class TestAsyncQuery(aiounittest.AsyncTestCase):
@staticmethod
def _get_target_class():
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/v1/test_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,31 @@ def test__firestore_api_property_with_emulator(
self.assertIs(client._firestore_api, mock_client.return_value)
self.assertEqual(mock_client.call_count, 1)

@mock.patch(
"google.cloud.firestore_v1.services.firestore.client.FirestoreClient",
autospec=True,
return_value=mock.sentinel.firestore_api,
)
@mock.patch(
"google.cloud.firestore_v1.services.firestore.transports.grpc.FirestoreGrpcTransport.create_channel",
autospec=True,
)
def test__firestore_api_property_with_default(
self, mock_insecure_channel, mock_client
):
client = self._make_default_one()
self.assertIsNone(client._firestore_api_internal)
firestore_api = client._firestore_api
self.assertIs(firestore_api, mock_client.return_value)
self.assertIs(firestore_api, client._firestore_api_internal)
mock_client.assert_called_once_with(
transport=client._transport, client_options=None
)

# Call again to show that it is cached, but call count is still 1.
self.assertIs(client._firestore_api, mock_client.return_value)
self.assertEqual(mock_client.call_count, 1)

def test___database_string_property(self):
credentials = _make_credentials()
database = "cheeeeez"
Expand Down
Loading