Skip to content

Commit 3ac80a9

Browse files
feat: replace internal dictionaries with protos in gapic calls (#875)
1 parent 94bfe66 commit 3ac80a9

File tree

9 files changed

+253
-101
lines changed

9 files changed

+253
-101
lines changed

google/cloud/bigtable/data/_async/_mutate_rows.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@
1616

1717
from typing import TYPE_CHECKING
1818
import asyncio
19+
from dataclasses import dataclass
1920
import functools
2021

2122
from google.api_core import exceptions as core_exceptions
2223
from google.api_core import retry_async as retries
24+
import google.cloud.bigtable_v2.types.bigtable as types_pb
2325
import google.cloud.bigtable.data.exceptions as bt_exceptions
2426
from google.cloud.bigtable.data._helpers import _make_metadata
2527
from google.cloud.bigtable.data._helpers import _convert_retry_deadline
@@ -36,6 +38,16 @@
3638
from google.cloud.bigtable.data._async.client import TableAsync
3739

3840

41+
@dataclass
42+
class _EntryWithProto:
43+
"""
44+
A dataclass to hold a RowMutationEntry and its corresponding proto representation.
45+
"""
46+
47+
entry: RowMutationEntry
48+
proto: types_pb.MutateRowsRequest.Entry
49+
50+
3951
class _MutateRowsOperationAsync:
4052
"""
4153
MutateRowsOperation manages the logic of sending a set of row mutations,
@@ -105,7 +117,7 @@ def __init__(
105117
self.timeout_generator = _attempt_timeout_generator(
106118
attempt_timeout, operation_timeout
107119
)
108-
self.mutations = mutation_entries
120+
self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries]
109121
self.remaining_indices = list(range(len(self.mutations)))
110122
self.errors: dict[int, list[Exception]] = {}
111123

@@ -136,7 +148,7 @@ async def start(self):
136148
cause_exc = exc_list[0]
137149
else:
138150
cause_exc = bt_exceptions.RetryExceptionGroup(exc_list)
139-
entry = self.mutations[idx]
151+
entry = self.mutations[idx].entry
140152
all_errors.append(
141153
bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc)
142154
)
@@ -154,9 +166,7 @@ async def _run_attempt(self):
154166
retry after the attempt is complete
155167
- GoogleAPICallError: if the gapic rpc fails
156168
"""
157-
request_entries = [
158-
self.mutations[idx]._to_dict() for idx in self.remaining_indices
159-
]
169+
request_entries = [self.mutations[idx].proto for idx in self.remaining_indices]
160170
# track mutations in this request that have not been finalized yet
161171
active_request_indices = {
162172
req_idx: orig_idx for req_idx, orig_idx in enumerate(self.remaining_indices)
@@ -214,7 +224,7 @@ def _handle_entry_error(self, idx: int, exc: Exception):
214224
- idx: the index of the mutation that failed
215225
- exc: the exception to add to the list
216226
"""
217-
entry = self.mutations[idx]
227+
entry = self.mutations[idx].entry
218228
self.errors.setdefault(idx, []).append(exc)
219229
if (
220230
entry.is_idempotent()

google/cloud/bigtable/data/_async/client.py

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -924,22 +924,17 @@ async def mutate_row(
924924
GoogleAPIError exceptions from any retries that failed
925925
- GoogleAPIError: raised on non-idempotent operations that cannot be
926926
safely retried.
927+
- ValueError if invalid arguments are provided
927928
"""
928929
operation_timeout, attempt_timeout = _get_timeouts(
929930
operation_timeout, attempt_timeout, self
930931
)
931932

932-
if isinstance(row_key, str):
933-
row_key = row_key.encode("utf-8")
934-
request = {"table_name": self.table_name, "row_key": row_key}
935-
if self.app_profile_id:
936-
request["app_profile_id"] = self.app_profile_id
933+
if not mutations:
934+
raise ValueError("No mutations provided")
935+
mutations_list = mutations if isinstance(mutations, list) else [mutations]
937936

938-
if isinstance(mutations, Mutation):
939-
mutations = [mutations]
940-
request["mutations"] = [mutation._to_dict() for mutation in mutations]
941-
942-
if all(mutation.is_idempotent() for mutation in mutations):
937+
if all(mutation.is_idempotent() for mutation in mutations_list):
943938
# mutations are all idempotent and safe to retry
944939
predicate = retries.if_exception_type(
945940
core_exceptions.DeadlineExceeded,
@@ -972,7 +967,13 @@ def on_error_fn(exc):
972967
metadata = _make_metadata(self.table_name, self.app_profile_id)
973968
# trigger rpc
974969
await deadline_wrapped(
975-
request, timeout=attempt_timeout, metadata=metadata, retry=None
970+
row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
971+
mutations=[mutation._to_pb() for mutation in mutations_list],
972+
table_name=self.table_name,
973+
app_profile_id=self.app_profile_id,
974+
timeout=attempt_timeout,
975+
metadata=metadata,
976+
retry=None,
976977
)
977978

978979
async def bulk_mutate_rows(
@@ -1009,6 +1010,7 @@ async def bulk_mutate_rows(
10091010
Raises:
10101011
- MutationsExceptionGroup if one or more mutations fails
10111012
Contains details about any failed entries in .exceptions
1013+
- ValueError if invalid arguments are provided
10121014
"""
10131015
operation_timeout, attempt_timeout = _get_timeouts(
10141016
operation_timeout, attempt_timeout, self
@@ -1065,29 +1067,24 @@ async def check_and_mutate_row(
10651067
- GoogleAPIError exceptions from grpc call
10661068
"""
10671069
operation_timeout, _ = _get_timeouts(operation_timeout, None, self)
1068-
row_key = row_key.encode("utf-8") if isinstance(row_key, str) else row_key
10691070
if true_case_mutations is not None and not isinstance(
10701071
true_case_mutations, list
10711072
):
10721073
true_case_mutations = [true_case_mutations]
1073-
true_case_dict = [m._to_dict() for m in true_case_mutations or []]
1074+
true_case_list = [m._to_pb() for m in true_case_mutations or []]
10741075
if false_case_mutations is not None and not isinstance(
10751076
false_case_mutations, list
10761077
):
10771078
false_case_mutations = [false_case_mutations]
1078-
false_case_dict = [m._to_dict() for m in false_case_mutations or []]
1079+
false_case_list = [m._to_pb() for m in false_case_mutations or []]
10791080
metadata = _make_metadata(self.table_name, self.app_profile_id)
10801081
result = await self.client._gapic_client.check_and_mutate_row(
1081-
request={
1082-
"predicate_filter": predicate._to_dict()
1083-
if predicate is not None
1084-
else None,
1085-
"true_mutations": true_case_dict,
1086-
"false_mutations": false_case_dict,
1087-
"table_name": self.table_name,
1088-
"row_key": row_key,
1089-
"app_profile_id": self.app_profile_id,
1090-
},
1082+
true_mutations=true_case_list,
1083+
false_mutations=false_case_list,
1084+
predicate_filter=predicate._to_pb() if predicate is not None else None,
1085+
row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
1086+
table_name=self.table_name,
1087+
app_profile_id=self.app_profile_id,
10911088
metadata=metadata,
10921089
timeout=operation_timeout,
10931090
retry=None,
@@ -1123,25 +1120,21 @@ async def read_modify_write_row(
11231120
operation
11241121
Raises:
11251122
- GoogleAPIError exceptions from grpc call
1123+
- ValueError if invalid arguments are provided
11261124
"""
11271125
operation_timeout, _ = _get_timeouts(operation_timeout, None, self)
1128-
row_key = row_key.encode("utf-8") if isinstance(row_key, str) else row_key
11291126
if operation_timeout <= 0:
11301127
raise ValueError("operation_timeout must be greater than 0")
11311128
if rules is not None and not isinstance(rules, list):
11321129
rules = [rules]
11331130
if not rules:
11341131
raise ValueError("rules must contain at least one item")
1135-
# concert to dict representation
1136-
rules_dict = [rule._to_dict() for rule in rules]
11371132
metadata = _make_metadata(self.table_name, self.app_profile_id)
11381133
result = await self.client._gapic_client.read_modify_write_row(
1139-
request={
1140-
"rules": rules_dict,
1141-
"table_name": self.table_name,
1142-
"row_key": row_key,
1143-
"app_profile_id": self.app_profile_id,
1144-
},
1134+
rules=[rule._to_pb() for rule in rules],
1135+
row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key,
1136+
table_name=self.table_name,
1137+
app_profile_id=self.app_profile_id,
11451138
metadata=metadata,
11461139
timeout=operation_timeout,
11471140
retry=None,

google/cloud/bigtable/data/_async/mutations_batcher.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,6 @@ async def _execute_mutate_rows(
342342
- list of FailedMutationEntryError objects for mutations that failed.
343343
FailedMutationEntryError objects will not contain index information
344344
"""
345-
request = {"table_name": self._table.table_name}
346-
if self._table.app_profile_id:
347-
request["app_profile_id"] = self._table.app_profile_id
348345
try:
349346
operation = _MutateRowsOperationAsync(
350347
self._table.client._gapic_client,

google/cloud/bigtable/data/mutations.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@
1919
from abc import ABC, abstractmethod
2020
from sys import getsizeof
2121

22+
import google.cloud.bigtable_v2.types.bigtable as types_pb
23+
import google.cloud.bigtable_v2.types.data as data_pb
2224

2325
from google.cloud.bigtable.data.read_modify_write_rules import _MAX_INCREMENT_VALUE
2426

27+
2528
# special value for SetCell mutation timestamps. If set, server will assign a timestamp
2629
_SERVER_SIDE_TIMESTAMP = -1
2730

@@ -36,6 +39,12 @@ class Mutation(ABC):
3639
def _to_dict(self) -> dict[str, Any]:
3740
raise NotImplementedError
3841

42+
def _to_pb(self) -> data_pb.Mutation:
43+
"""
44+
Convert the mutation to protobuf
45+
"""
46+
return data_pb.Mutation(**self._to_dict())
47+
3948
def is_idempotent(self) -> bool:
4049
"""
4150
Check if the mutation is idempotent
@@ -221,6 +230,12 @@ def _to_dict(self) -> dict[str, Any]:
221230
"mutations": [mutation._to_dict() for mutation in self.mutations],
222231
}
223232

233+
def _to_pb(self) -> types_pb.MutateRowsRequest.Entry:
234+
return types_pb.MutateRowsRequest.Entry(
235+
row_key=self.row_key,
236+
mutations=[mutation._to_pb() for mutation in self.mutations],
237+
)
238+
224239
def is_idempotent(self) -> bool:
225240
"""Check if the mutation is idempotent"""
226241
return all(mutation.is_idempotent() for mutation in self.mutations)

google/cloud/bigtable/data/read_modify_write_rules.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import abc
1818

19+
import google.cloud.bigtable_v2.types.data as data_pb
20+
1921
# value must fit in 64-bit signed integer
2022
_MAX_INCREMENT_VALUE = (1 << 63) - 1
2123

@@ -29,9 +31,12 @@ def __init__(self, family: str, qualifier: bytes | str):
2931
self.qualifier = qualifier
3032

3133
@abc.abstractmethod
32-
def _to_dict(self):
34+
def _to_dict(self) -> dict[str, str | bytes | int]:
3335
raise NotImplementedError
3436

37+
def _to_pb(self) -> data_pb.ReadModifyWriteRule:
38+
return data_pb.ReadModifyWriteRule(**self._to_dict())
39+
3540

3641
class IncrementRule(ReadModifyWriteRule):
3742
def __init__(self, family: str, qualifier: bytes | str, increment_amount: int = 1):
@@ -44,7 +49,7 @@ def __init__(self, family: str, qualifier: bytes | str, increment_amount: int =
4449
super().__init__(family, qualifier)
4550
self.increment_amount = increment_amount
4651

47-
def _to_dict(self):
52+
def _to_dict(self) -> dict[str, str | bytes | int]:
4853
return {
4954
"family_name": self.family,
5055
"column_qualifier": self.qualifier,
@@ -64,7 +69,7 @@ def __init__(self, family: str, qualifier: bytes | str, append_value: bytes | st
6469
super().__init__(family, qualifier)
6570
self.append_value = append_value
6671

67-
def _to_dict(self):
72+
def _to_dict(self) -> dict[str, str | bytes | int]:
6873
return {
6974
"family_name": self.family,
7075
"column_qualifier": self.qualifier,

tests/unit/data/_async/test__mutate_rows.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def test_ctor(self):
7575
"""
7676
test that constructor sets all the attributes correctly
7777
"""
78+
from google.cloud.bigtable.data._async._mutate_rows import _EntryWithProto
7879
from google.cloud.bigtable.data.exceptions import _MutateRowsIncomplete
7980
from google.api_core.exceptions import DeadlineExceeded
8081
from google.api_core.exceptions import ServiceUnavailable
@@ -103,7 +104,8 @@ def test_ctor(self):
103104
assert str(table.table_name) in metadata[0][1]
104105
assert str(table.app_profile_id) in metadata[0][1]
105106
# entries should be passed down
106-
assert instance.mutations == entries
107+
entries_w_pb = [_EntryWithProto(e, e._to_pb()) for e in entries]
108+
assert instance.mutations == entries_w_pb
107109
# timeout_gen should generate per-attempt timeout
108110
assert next(instance.timeout_generator) == attempt_timeout
109111
# ensure predicate is set
@@ -306,7 +308,7 @@ async def test_run_attempt_single_entry_success(self):
306308
assert mock_gapic_fn.call_count == 1
307309
_, kwargs = mock_gapic_fn.call_args
308310
assert kwargs["timeout"] == expected_timeout
309-
assert kwargs["entries"] == [mutation._to_dict()]
311+
assert kwargs["entries"] == [mutation._to_pb()]
310312

311313
@pytest.mark.asyncio
312314
async def test_run_attempt_empty_request(self):

0 commit comments

Comments
 (0)