Commit 67d8760

perf: Rechunk result pages client side (#1680)
1 parent 5c125c9 commit 67d8760

8 files changed: +182, -43 lines changed


bigframes/core/blocks.py

Lines changed: 3 additions & 3 deletions
@@ -586,10 +586,10 @@ def to_pandas_batches(
             self.expr,
             ordered=True,
             use_explicit_destination=allow_large_results,
-            page_size=page_size,
-            max_results=max_results,
         )
-        for df in execute_result.to_pandas_batches():
+        for df in execute_result.to_pandas_batches(
+            page_size=page_size, max_results=max_results
+        ):
             self._copy_index_to_pandas(df)
             if squeeze:
                 yield df.squeeze(axis=1)
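
For callers, the visible behavior is unchanged: Block.to_pandas_batches still accepts page_size and max_results, but they are now applied client side to the Arrow batches rather than being pushed into the BigQuery row iterator. A minimal sketch of the user-facing call path, assuming the public DataFrame.to_pandas_batches forwards these arguments to the block as the hunk above suggests (the table name is just an example public dataset):

import bigframes.pandas as bpd

# Example public table; any query result pages the same way.
df = bpd.read_gbq("bigquery-public-data.usa_names.usa_1910_2013")

# page_size and max_results are applied after the result is fetched,
# by truncating and rechunking the Arrow record batches client side.
for page in df.to_pandas_batches(page_size=10_000, max_results=50_000):
    print(page.shape)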

bigframes/core/pyarrow_utils.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Iterable, Iterator
+
+import pyarrow as pa
+
+
+class BatchBuffer:
+    """
+    FIFO buffer of pyarrow Record batches
+
+    Not thread-safe.
+    """
+
+    def __init__(self):
+        self._buffer: list[pa.RecordBatch] = []
+        self._buffer_size: int = 0
+
+    def __len__(self):
+        return self._buffer_size
+
+    def append_batch(self, batch: pa.RecordBatch) -> None:
+        self._buffer.append(batch)
+        self._buffer_size += batch.num_rows
+
+    def take_as_batches(self, n: int) -> tuple[pa.RecordBatch, ...]:
+        if n > len(self):
+            raise ValueError(f"Cannot take {n} rows, only {len(self)} rows in buffer.")
+        rows_taken = 0
+        sub_batches: list[pa.RecordBatch] = []
+        while rows_taken < n:
+            batch = self._buffer.pop(0)
+            if batch.num_rows > (n - rows_taken):
+                sub_batches.append(batch.slice(length=n - rows_taken))
+                self._buffer.insert(0, batch.slice(offset=n - rows_taken))
+                rows_taken += n - rows_taken
+            else:
+                sub_batches.append(batch)
+                rows_taken += batch.num_rows
+
+        self._buffer_size -= n
+        return tuple(sub_batches)
+
+    def take_rechunked(self, n: int) -> pa.RecordBatch:
+        return (
+            pa.Table.from_batches(self.take_as_batches(n))
+            .combine_chunks()
+            .to_batches()[0]
+        )
+
+
+def chunk_by_row_count(
+    batches: Iterable[pa.RecordBatch], page_size: int
+) -> Iterator[tuple[pa.RecordBatch, ...]]:
+    buffer = BatchBuffer()
+    for batch in batches:
+        buffer.append_batch(batch)
+        while len(buffer) >= page_size:
+            yield buffer.take_as_batches(page_size)
+
+    # emit final page, maybe smaller
+    if len(buffer) > 0:
+        yield buffer.take_as_batches(len(buffer))
+
+
+def truncate_pyarrow_iterable(
+    batches: Iterable[pa.RecordBatch], max_results: int
+) -> Iterator[pa.RecordBatch]:
+    total_yielded = 0
+    for batch in batches:
+        if batch.num_rows >= (max_results - total_yielded):
+            yield batch.slice(length=max_results - total_yielded)
+            return
+        else:
+            yield batch
+            total_yielded += batch.num_rows
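
The two module-level helpers compose: truncate_pyarrow_iterable caps the total row count, and chunk_by_row_count regroups whatever batch sizes arrive into fixed-size pages. A small standalone sketch using synthetic batches (not part of the commit):

import pyarrow as pa

from bigframes.core import pyarrow_utils

# A 10-row table split into uneven batches of 4, 4, and 2 rows.
table = pa.table({"x": list(range(10))})
batches = table.to_batches(max_chunksize=4)

# Keep only the first 7 rows, then regroup them into 3-row pages.
truncated = pyarrow_utils.truncate_pyarrow_iterable(batches, max_results=7)
pages = pyarrow_utils.chunk_by_row_count(truncated, page_size=3)

for page in pages:
    # Each page is a tuple of RecordBatches; the last page may be smaller.
    print(sum(batch.num_rows for batch in page))  # 3, 3, 1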

bigframes/session/_io/bigquery/__init__.py

Lines changed: 1 addition & 9 deletions
@@ -222,8 +222,6 @@ def start_query_with_client(
     job_config: bigquery.job.QueryJobConfig,
     location: Optional[str] = None,
     project: Optional[str] = None,
-    max_results: Optional[int] = None,
-    page_size: Optional[int] = None,
     timeout: Optional[float] = None,
     api_name: Optional[str] = None,
     metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None,
@@ -244,8 +242,6 @@ def start_query_with_client(
             location=location,
             project=project,
             api_timeout=timeout,
-            page_size=page_size,
-            max_results=max_results,
         )
         if metrics is not None:
             metrics.count_job_stats(row_iterator=results_iterator)
@@ -267,14 +263,10 @@ def start_query_with_client(
     if opts.progress_bar is not None and not query_job.configuration.dry_run:
         results_iterator = formatting_helpers.wait_for_query_job(
             query_job,
-            max_results=max_results,
             progress_bar=opts.progress_bar,
-            page_size=page_size,
         )
     else:
-        results_iterator = query_job.result(
-            max_results=max_results, page_size=page_size
-        )
+        results_iterator = query_job.result()

     if metrics is not None:
         metrics.count_job_stats(query_job=query_job)

bigframes/session/bq_caching_executor.py

Lines changed: 2 additions & 21 deletions
@@ -106,8 +106,6 @@ def execute(
         *,
         ordered: bool = True,
         use_explicit_destination: Optional[bool] = None,
-        page_size: Optional[int] = None,
-        max_results: Optional[int] = None,
     ) -> executor.ExecuteResult:
         if use_explicit_destination is None:
             use_explicit_destination = bigframes.options.bigquery.allow_large_results
@@ -127,8 +125,6 @@ def execute(
         return self._execute_plan(
             plan,
             ordered=ordered,
-            page_size=page_size,
-            max_results=max_results,
             destination=destination_table,
         )

@@ -281,8 +277,6 @@ def _run_execute_query(
         sql: str,
         job_config: Optional[bq_job.QueryJobConfig] = None,
         api_name: Optional[str] = None,
-        page_size: Optional[int] = None,
-        max_results: Optional[int] = None,
         query_with_job: bool = True,
     ) -> Tuple[bq_table.RowIterator, Optional[bigquery.QueryJob]]:
         """
@@ -303,8 +297,6 @@ def _run_execute_query(
                 sql,
                 job_config=job_config,
                 api_name=api_name,
-                max_results=max_results,
-                page_size=page_size,
                 metrics=self.metrics,
                 query_with_job=query_with_job,
             )
@@ -479,16 +471,13 @@ def _execute_plan(
         self,
         plan: nodes.BigFrameNode,
         ordered: bool,
-        page_size: Optional[int] = None,
-        max_results: Optional[int] = None,
         destination: Optional[bq_table.TableReference] = None,
         peek: Optional[int] = None,
     ):
         """Just execute whatever plan as is, without further caching or decomposition."""

         # First try to execute fast-paths
-        # TODO: Allow page_size and max_results by rechunking/truncating results
-        if (not page_size) and (not max_results) and (not destination) and (not peek):
+        if (not destination) and (not peek):
             for semi_executor in self._semi_executors:
                 maybe_result = semi_executor.execute(plan, ordered=ordered)
                 if maybe_result:
@@ -504,20 +493,12 @@ def _execute_plan(
         iterator, query_job = self._run_execute_query(
             sql=sql,
             job_config=job_config,
-            page_size=page_size,
-            max_results=max_results,
             query_with_job=(destination is not None),
         )

         # Though we provide the read client, iterator may or may not use it based on what is efficient for the result
         def iterator_supplier():
-            # Workaround issue fixed by: https://github.com/googleapis/python-bigquery/pull/2154
-            if iterator._page_size is not None or iterator.max_results is not None:
-                return iterator.to_arrow_iterable(bqstorage_client=None)
-            else:
-                return iterator.to_arrow_iterable(
-                    bqstorage_client=self.bqstoragereadclient
-                )
+            return iterator.to_arrow_iterable(bqstorage_client=self.bqstoragereadclient)

         if query_job:
             size_bytes = self.bqclient.get_table(query_job.destination).num_bytes

bigframes/session/executor.py

Lines changed: 21 additions & 4 deletions
@@ -25,6 +25,7 @@
 import pyarrow

 import bigframes.core
+from bigframes.core import pyarrow_utils
 import bigframes.core.schema
 import bigframes.session._io.pandas as io_pandas

@@ -55,10 +56,28 @@ def to_arrow_table(self) -> pyarrow.Table:
     def to_pandas(self) -> pd.DataFrame:
         return io_pandas.arrow_to_pandas(self.to_arrow_table(), self.schema)

-    def to_pandas_batches(self) -> Iterator[pd.DataFrame]:
+    def to_pandas_batches(
+        self, page_size: Optional[int] = None, max_results: Optional[int] = None
+    ) -> Iterator[pd.DataFrame]:
+        assert (page_size is None) or (page_size > 0)
+        assert (max_results is None) or (max_results > 0)
+        batch_iter: Iterator[
+            Union[pyarrow.Table, pyarrow.RecordBatch]
+        ] = self.arrow_batches()
+        if max_results is not None:
+            batch_iter = pyarrow_utils.truncate_pyarrow_iterable(
+                batch_iter, max_results
+            )
+
+        if page_size is not None:
+            batches_iter = pyarrow_utils.chunk_by_row_count(batch_iter, page_size)
+            batch_iter = map(
+                lambda batches: pyarrow.Table.from_batches(batches), batches_iter
+            )
+
         yield from map(
             functools.partial(io_pandas.arrow_to_pandas, schema=self.schema),
-            self.arrow_batches(),
+            batch_iter,
         )

     def to_py_scalar(self):
@@ -107,8 +126,6 @@ def execute(
         *,
         ordered: bool = True,
         use_explicit_destination: Optional[bool] = False,
-        page_size: Optional[int] = None,
-        max_results: Optional[int] = None,
     ) -> ExecuteResult:
         """
         Execute the ArrayValue, storing the result to a temporary session-owned table.
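
The new to_pandas_batches body amounts to the pipeline below: truncate first so at most max_results rows flow downstream, then regroup into page_size pages, then convert each page to pandas. A simplified standalone sketch of that ordering (plain to_pandas stands in for io_pandas.arrow_to_pandas, which the executor uses to preserve BigQuery DataFrames dtypes):

from typing import Iterable, Optional

import pyarrow as pa

from bigframes.core import pyarrow_utils


def pandas_pages(
    batches: Iterable[pa.RecordBatch],
    page_size: Optional[int] = None,
    max_results: Optional[int] = None,
):
    # Order matters: truncation runs before rechunking, so the row cap
    # also bounds the final, possibly partial page.
    if max_results is not None:
        batches = pyarrow_utils.truncate_pyarrow_iterable(batches, max_results)
    if page_size is not None:
        grouped = pyarrow_utils.chunk_by_row_count(batches, page_size)
        batches = (pa.Table.from_batches(group) for group in grouped)
    for item in batches:
        yield item.to_pandas()  # stand-in for io_pandas.arrow_to_pandas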

bigframes/session/loader.py

Lines changed: 0 additions & 2 deletions
@@ -906,7 +906,6 @@ def _start_query(
         self,
         sql: str,
         job_config: Optional[google.cloud.bigquery.QueryJobConfig] = None,
-        max_results: Optional[int] = None,
         timeout: Optional[float] = None,
         api_name: Optional[str] = None,
     ) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]:
@@ -925,7 +924,6 @@ def _start_query(
             self._bqclient,
             sql,
             job_config=job_config,
-            max_results=max_results,
             timeout=timeout,
             api_name=api_name,
         )

tests/unit/core/test_pyarrow_utils.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+
+import numpy as np
+import pyarrow as pa
+import pytest
+
+from bigframes.core import pyarrow_utils
+
+PA_TABLE = pa.table({f"col_{i}": np.random.rand(1000) for i in range(10)})
+
+# 17, 3, 929 coprime
+N = 17
+MANY_SMALL_BATCHES = PA_TABLE.to_batches(max_chunksize=3)
+FEW_BIG_BATCHES = PA_TABLE.to_batches(max_chunksize=929)
+
+
+@pytest.mark.parametrize(
+    ["batches", "page_size"],
+    [
+        (MANY_SMALL_BATCHES, N),
+        (FEW_BIG_BATCHES, N),
+    ],
+)
+def test_chunk_by_row_count(batches, page_size):
+    results = list(pyarrow_utils.chunk_by_row_count(batches, page_size=page_size))
+
+    for i, batches in enumerate(results):
+        if i != len(results) - 1:
+            assert sum(map(lambda x: x.num_rows, batches)) == page_size
+        else:
+            # final page can be smaller
+            assert sum(map(lambda x: x.num_rows, batches)) <= page_size
+
+    reconstructed = pa.Table.from_batches(itertools.chain.from_iterable(results))
+    assert reconstructed.equals(PA_TABLE)
+
+
+@pytest.mark.parametrize(
+    ["batches", "max_rows"],
+    [
+        (MANY_SMALL_BATCHES, N),
+        (FEW_BIG_BATCHES, N),
+    ],
+)
+def test_truncate_pyarrow_iterable(batches, max_rows):
+    results = list(
+        pyarrow_utils.truncate_pyarrow_iterable(batches, max_results=max_rows)
+    )
+
+    reconstructed = pa.Table.from_batches(results)
+    assert reconstructed.equals(PA_TABLE.slice(length=max_rows))

tests/unit/session/test_io_bigquery.py

Lines changed: 3 additions & 4 deletions
@@ -199,11 +199,11 @@ def test_add_and_trim_labels_length_limit_met():


 @pytest.mark.parametrize(
-    ("max_results", "timeout", "api_name"),
-    [(None, None, None), (100, 30.0, "test_api")],
+    ("timeout", "api_name"),
+    [(None, None), (30.0, "test_api")],
 )
 def test_start_query_with_client_labels_length_limit_met(
-    mock_bq_client, max_results, timeout, api_name
+    mock_bq_client, timeout, api_name
 ):
     sql = "select * from abc"
     cur_labels = {
@@ -230,7 +230,6 @@ def test_start_query_with_client_labels_length_limit_met(
         mock_bq_client,
         sql,
         job_config,
-        max_results=max_results,
         timeout=timeout,
         api_name=api_name,
     )
