Skip to content

Commit 02a4317

Browse files
refactor: added BaseQuery._copy method (#406)
* refactor: added BaseQuery.copy method * 🦉 Updates from OwlBot See https://github.com/googleapis/repo-automation-bots/blob/master/packages/owl-bot/README.md * responded to code review * migrated last copy location * moved _not_passed check to identity instead of equality Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 90d0af3 commit 02a4317

File tree

1 file changed

+39
-70
lines changed
  • packages/google-cloud-firestore/google/cloud/firestore_v1

1 file changed

+39
-70
lines changed

packages/google-cloud-firestore/google/cloud/firestore_v1/base_query.py

Lines changed: 39 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@
8585
)
8686
_MISMATCH_CURSOR_W_ORDER_BY = "The cursor {!r} does not match the order fields {!r}."
8787

88+
_not_passed = object()
89+
8890

8991
class BaseQuery(object):
9092
"""Represents a query to the Firestore API.
@@ -231,19 +233,41 @@ def select(self, field_paths: Iterable[str]) -> "BaseQuery":
231233
for field_path in field_paths
232234
]
233235
)
236+
return self._copy(projection=new_projection)
237+
238+
def _copy(
239+
self,
240+
*,
241+
projection: Optional[query.StructuredQuery.Projection] = _not_passed,
242+
field_filters: Optional[Tuple[query.StructuredQuery.FieldFilter]] = _not_passed,
243+
orders: Optional[Tuple[query.StructuredQuery.Order]] = _not_passed,
244+
limit: Optional[int] = _not_passed,
245+
limit_to_last: Optional[bool] = _not_passed,
246+
offset: Optional[int] = _not_passed,
247+
start_at: Optional[Tuple[dict, bool]] = _not_passed,
248+
end_at: Optional[Tuple[dict, bool]] = _not_passed,
249+
all_descendants: Optional[bool] = _not_passed,
250+
) -> "BaseQuery":
234251
return self.__class__(
235252
self._parent,
236-
projection=new_projection,
237-
field_filters=self._field_filters,
238-
orders=self._orders,
239-
limit=self._limit,
240-
limit_to_last=self._limit_to_last,
241-
offset=self._offset,
242-
start_at=self._start_at,
243-
end_at=self._end_at,
244-
all_descendants=self._all_descendants,
253+
projection=self._evaluate_param(projection, self._projection),
254+
field_filters=self._evaluate_param(field_filters, self._field_filters),
255+
orders=self._evaluate_param(orders, self._orders),
256+
limit=self._evaluate_param(limit, self._limit),
257+
limit_to_last=self._evaluate_param(limit_to_last, self._limit_to_last),
258+
offset=self._evaluate_param(offset, self._offset),
259+
start_at=self._evaluate_param(start_at, self._start_at),
260+
end_at=self._evaluate_param(end_at, self._end_at),
261+
all_descendants=self._evaluate_param(
262+
all_descendants, self._all_descendants
263+
),
245264
)
246265

266+
def _evaluate_param(self, value, fallback_value):
267+
"""Helper which allows `None` to be passed into `copy` and be set on the
268+
copy instead of being misinterpreted as an unpassed parameter."""
269+
return value if value is not _not_passed else fallback_value
270+
247271
def where(self, field_path: str, op_string: str, value) -> "BaseQuery":
248272
"""Filter the query on a field.
249273
@@ -301,18 +325,7 @@ def where(self, field_path: str, op_string: str, value) -> "BaseQuery":
301325
)
302326

303327
new_filters = self._field_filters + (filter_pb,)
304-
return self.__class__(
305-
self._parent,
306-
projection=self._projection,
307-
field_filters=new_filters,
308-
orders=self._orders,
309-
limit=self._limit,
310-
offset=self._offset,
311-
limit_to_last=self._limit_to_last,
312-
start_at=self._start_at,
313-
end_at=self._end_at,
314-
all_descendants=self._all_descendants,
315-
)
328+
return self._copy(field_filters=new_filters)
316329

317330
@staticmethod
318331
def _make_order(field_path, direction) -> StructuredQuery.Order:
@@ -354,18 +367,7 @@ def order_by(self, field_path: str, direction: str = ASCENDING) -> "BaseQuery":
354367
order_pb = self._make_order(field_path, direction)
355368

356369
new_orders = self._orders + (order_pb,)
357-
return self.__class__(
358-
self._parent,
359-
projection=self._projection,
360-
field_filters=self._field_filters,
361-
orders=new_orders,
362-
limit=self._limit,
363-
limit_to_last=self._limit_to_last,
364-
offset=self._offset,
365-
start_at=self._start_at,
366-
end_at=self._end_at,
367-
all_descendants=self._all_descendants,
368-
)
370+
return self._copy(orders=new_orders)
369371

370372
def limit(self, count: int) -> "BaseQuery":
371373
"""Limit a query to return at most `count` matching results.
@@ -384,18 +386,7 @@ def limit(self, count: int) -> "BaseQuery":
384386
A limited query. Acts as a copy of the current query, modified
385387
with the newly added "limit" filter.
386388
"""
387-
return self.__class__(
388-
self._parent,
389-
projection=self._projection,
390-
field_filters=self._field_filters,
391-
orders=self._orders,
392-
limit=count,
393-
limit_to_last=False,
394-
offset=self._offset,
395-
start_at=self._start_at,
396-
end_at=self._end_at,
397-
all_descendants=self._all_descendants,
398-
)
389+
return self._copy(limit=count, limit_to_last=False)
399390

400391
def limit_to_last(self, count: int) -> "BaseQuery":
401392
"""Limit a query to return the last `count` matching results.
@@ -414,18 +405,7 @@ def limit_to_last(self, count: int) -> "BaseQuery":
414405
A limited query. Acts as a copy of the current query, modified
415406
with the newly added "limit" filter.
416407
"""
417-
return self.__class__(
418-
self._parent,
419-
projection=self._projection,
420-
field_filters=self._field_filters,
421-
orders=self._orders,
422-
limit=count,
423-
limit_to_last=True,
424-
offset=self._offset,
425-
start_at=self._start_at,
426-
end_at=self._end_at,
427-
all_descendants=self._all_descendants,
428-
)
408+
return self._copy(limit=count, limit_to_last=True)
429409

430410
def offset(self, num_to_skip: int) -> "BaseQuery":
431411
"""Skip to an offset in a query.
@@ -442,18 +422,7 @@ def offset(self, num_to_skip: int) -> "BaseQuery":
442422
An offset query. Acts as a copy of the current query, modified
443423
with the newly added "offset" field.
444424
"""
445-
return self.__class__(
446-
self._parent,
447-
projection=self._projection,
448-
field_filters=self._field_filters,
449-
orders=self._orders,
450-
limit=self._limit,
451-
limit_to_last=self._limit_to_last,
452-
offset=num_to_skip,
453-
start_at=self._start_at,
454-
end_at=self._end_at,
455-
all_descendants=self._all_descendants,
456-
)
425+
return self._copy(offset=num_to_skip)
457426

458427
def _check_snapshot(self, document_snapshot) -> None:
459428
"""Validate local snapshots for non-collection-group queries.
@@ -523,7 +492,7 @@ def _cursor_helper(
523492
query_kwargs["start_at"] = self._start_at
524493
query_kwargs["end_at"] = cursor_pair
525494

526-
return self.__class__(self._parent, **query_kwargs)
495+
return self._copy(**query_kwargs)
527496

528497
def start_at(
529498
self, document_fields_or_snapshot: Union[DocumentSnapshot, dict, list, tuple]

0 commit comments

Comments
 (0)