 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterable, Awaitable
 
 from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB
+from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB
 from google.cloud.bigtable_v2.types import RowSet as RowSetPB
 from google.cloud.bigtable_v2.types import RowRange as RowRangePB
 
 
 
 class _ResetRow(Exception):
-    pass
+    def __init__(self, chunk):
+        self.chunk = chunk
 
 
 class _ReadRowsOperationAsync:
@@ -83,7 +85,10 @@ def __init__(
         self._last_yielded_row_key: bytes | None = None
         self._remaining_count = self.request.rows_limit or None
 
-    async def start_operation(self):
+    async def start_operation(self) -> AsyncGenerator[Row, None]:
+        """
+        Start the read_rows operation, retrying on retryable errors.
+        """
         transient_errors = []
 
         def on_error_fn(exc):
@@ -103,7 +108,12 @@ def on_error_fn(exc):
         except core_exceptions.RetryError:
             self._raise_retry_error(transient_errors)
 
-    def read_rows_attempt(self):
+    def read_rows_attempt(self) -> AsyncGenerator[Row, None]:
+        """
+        Single read_rows attempt. This function is intended to be wrapped
+        by retry logic and called once for each attempt.
+        """
+        # revise request keys and ranges between attempts
         if self._last_yielded_row_key is not None:
             # if this is a retry, try to trim down the request to avoid ones we've already processed
             try:
@@ -112,22 +122,32 @@ def read_rows_attempt(self):
                     last_seen_row_key=self._last_yielded_row_key,
                 )
             except _RowSetComplete:
-                return
+                # if we've already seen all the rows, we're done
+                return self.merge_rows(None)
+        # revise the limit based on number of rows already yielded
         if self._remaining_count is not None:
             self.request.rows_limit = self._remaining_count
-        s = self.table.client._gapic_client.read_rows(
+        # create and return a new row merger
+        gapic_stream = self.table.client._gapic_client.read_rows(
             self.request,
             timeout=next(self.attempt_timeout_gen),
             metadata=self._metadata,
         )
-        s = self.chunk_stream(s)
-        return self.merge_rows(s)
-
+        chunked_stream = self.chunk_stream(gapic_stream)
+        return self.merge_rows(chunked_stream)
 
-    async def chunk_stream(self, stream):
+    async def chunk_stream(
+        self, stream: Awaitable[AsyncIterable[ReadRowsResponsePB]]
+    ) -> AsyncGenerator[ReadRowsResponsePB.CellChunk, None]:
+        """
+        process chunks out of raw read_rows stream
+        """
         async for resp in await stream:
+            # extract proto from proto-plus wrapper
             resp = resp._pb
 
+            # handle last_scanned_row_key packets, sent when server
+            # has scanned past the end of the row range
             if resp.last_scanned_row_key:
                 if (
                     self._last_yielded_row_key is not None
@@ -137,7 +157,7 @@ async def chunk_stream(self, stream):
                 self._last_yielded_row_key = resp.last_scanned_row_key
 
             current_key = None
-
+            # process each chunk in the response
             for c in resp.chunks:
                 if current_key is None:
                     current_key = c.row_key
@@ -154,12 +174,18 @@ async def chunk_stream(self, stream):
                 if c.reset_row:
                     current_key = None
                 elif c.commit_row:
+                    # update row state after each commit
                     self._last_yielded_row_key = current_key
                     if self._remaining_count is not None:
                         self._remaining_count -= 1
 
     @staticmethod
-    async def merge_rows(chunks):
+    async def merge_rows(chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None):
+        """
+        Merge chunks into rows
+        """
+        if chunks is None:
+            return
         it = chunks.__aiter__()
         # For each row
         while True:
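
The change above keeps read_rows_attempt's return type uniform: when the revised row set is already complete it returns merge_rows(None), and merge_rows exits immediately when chunks is None, so the caller always receives an async generator it can iterate. A minimal self-contained sketch of that behavior (simplified stand-in names, not the library's API):

import asyncio
from typing import AsyncGenerator


async def merge(chunks) -> AsyncGenerator[bytes, None]:
    # simplified stand-in for merge_rows: exit immediately on None,
    # otherwise yield each chunk as-is
    if chunks is None:
        return
    for chunk in chunks:
        yield chunk


async def main():
    print([c async for c in merge(None)])          # -> []
    print([c async for c in merge([b"a", b"b"])])  # -> [b'a', b'b']


asyncio.run(main())
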
@@ -173,27 +199,17 @@ async def merge_rows(chunks):
             if not row_key:
                 raise InvalidChunk("first row chunk is missing key")
 
-            # Cells
             cells = []
 
             # shared per cell storage
-            family: str | None = None
-            qualifier: bytes | None = None
+            family: str | None = None
+            qualifier: bytes | None = None
 
             try:
                 # for each cell
                 while True:
                     if c.reset_row:
-                        if (
-                            c.row_key
-                            or c.HasField("family_name")
-                            or c.HasField("qualifier")
-                            or c.timestamp_micros
-                            or c.labels
-                            or c.value
-                        ):
-                            raise InvalidChunk("reset row with data")
-                        raise _ResetRow()
+                        raise _ResetRow(c)
                     k = c.row_key
                     f = c.family_name.value
                     q = c.qualifier.value if c.HasField("qualifier") else None
@@ -242,28 +258,31 @@ async def merge_rows(chunks):
                             raise InvalidChunk("row key changed mid cell")
 
                         if c.reset_row:
-                            if (
-                                c.row_key
-                                or c.HasField("family_name")
-                                or c.HasField("qualifier")
-                                or c.timestamp_micros
-                                or c.labels
-                                or c.value
-                            ):
-                                raise InvalidChunk("reset_row with non-empty value")
-                            raise _ResetRow()
+                            raise _ResetRow(c)
                         buffer.append(c.value)
                     value = b"".join(buffer)
                     if family is None:
                         raise InvalidChunk("missing family")
                     if qualifier is None:
                         raise InvalidChunk("missing qualifier")
-                    cells.append(Cell(value, row_key, family, qualifier, ts, list(labels)))
+                    cells.append(
+                        Cell(value, row_key, family, qualifier, ts, list(labels))
+                    )
                     if c.commit_row:
                         yield Row(row_key, cells)
                         break
                     c = await it.__anext__()
-            except _ResetRow:
+            except _ResetRow as e:
+                c = e.chunk
+                if (
+                    c.row_key
+                    or c.HasField("family_name")
+                    or c.HasField("qualifier")
+                    or c.timestamp_micros
+                    or c.labels
+                    or c.value
+                ):
+                    raise InvalidChunk("reset row with data")
                 continue
             except StopAsyncIteration:
                 raise InvalidChunk("premature end of stream")
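
The two hunks above replace duplicated inline validation with a single check in the except handler: the sentinel exception now carries the offending chunk, and the handler decides whether the reset was legal. A small standalone sketch of that pattern (illustrative names only, not the library's API):

class _Reset(Exception):
    # sentinel that carries the chunk which triggered the reset
    def __init__(self, chunk):
        self.chunk = chunk


def merge(chunks):
    """Collect chunk values, discarding everything on a reset chunk."""
    out = []
    try:
        for chunk in chunks:
            if chunk.get("reset"):
                # hand the chunk to the handler instead of validating it
                # at every raise site
                raise _Reset(chunk)
            out.append(chunk["value"])
    except _Reset as e:
        # single place to validate: a reset chunk must not carry data
        if e.chunk.get("value"):
            raise ValueError("reset chunk with data")
        return []
    return out


assert merge([{"value": 1}, {"value": 2}]) == [1, 2]
assert merge([{"value": 1}, {"reset": True}]) == []
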
@@ -301,15 +320,19 @@ def _revise_request_rowset(
                 if start_key is None or start_key <= last_seen_row_key:
                     # replace start key with last seen
                     new_range.start_key_open = last_seen_row_key
-                    new_range.start_key_closed = None
+                    new_range.start_key_closed = b''
                 adjusted_ranges.append(new_range)
         if len(adjusted_keys) == 0 and len(adjusted_ranges) == 0:
             # if the query is empty after revision, raise an exception
             # this will avoid an unwanted full table scan
             raise _RowSetComplete()
         return RowSetPB(row_keys=adjusted_keys, row_ranges=adjusted_ranges)
 
-    def _raise_retry_error(self, transient_errors):
+    def _raise_retry_error(self, transient_errors: list[Exception]) -> None:
+        """
+        If the retryable deadline is hit, wrap the raised exception
+        in a RetryExceptionGroup
+        """
         timeout_value = self.operation_timeout
         timeout_str = f" of {timeout_value:.1f}s" if timeout_value is not None else ""
         error_str = f"operation_timeout{timeout_str} exceeded"
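
The hunk above touches _revise_request_rowset, which trims the request between attempts: keys at or before the last-yielded key are dropped, ranges that start at or before it are reopened just past it, and an empty result raises _RowSetComplete rather than allowing an accidental full-table scan. A rough plain-Python sketch of that rule (hypothetical names, closed start/end keys only):

class RowSetComplete(Exception):
    pass


def revise(row_keys, row_ranges, last_seen):
    """Trim keys and (start, end) ranges already covered by last_seen."""
    keys = [k for k in row_keys if k > last_seen]
    ranges = []
    for start, end in row_ranges:
        if end is None or end > last_seen:
            if start is None or start <= last_seen:
                # restart the range just past the last-seen key
                start = last_seen
            ranges.append((start, end))
    if not keys and not ranges:
        # nothing left to request
        raise RowSetComplete()
    return keys, ranges


assert revise([b"a", b"c"], [(b"a", b"z")], b"b") == ([b"c"], [(b"b", b"z")])
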