Commit cd2bb1b

perf: Cache results opportunistically
1 parent 4c5dee5

File tree

7 files changed (+182, −88 lines)

7 files changed

+182
-88
lines changed

bigframes/core/compile/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -14,8 +14,13 @@
 from __future__ import annotations

 from bigframes.core.compile.api import SQLCompiler, test_only_ibis_inferred_schema
+from bigframes.core.compile.compiler import compile_sql
+from bigframes.core.compile.configs import CompileRequest, CompileResult

 __all__ = [
     "SQLCompiler",
     "test_only_ibis_inferred_schema",
     "compile_sql",
     "CompileRequest",
     "CompileResult",
 ]
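
These exports expose the functional compiler that the executor now uses in place of its SQLCompiler instance. A minimal sketch of the call pattern, using only the fields this commit exercises (CompileRequest's sort_rows / peek_count / materialize_all_order_keys and CompileResult's sql / sql_schema / row_order); the helper name compile_for_cache is hypothetical:

from bigframes.core import compile
import bigframes.core.nodes as nodes


def compile_for_cache(plan: nodes.BigFrameNode):
    # Hypothetical helper mirroring how bq_caching_executor.py drives the new API.
    compiled = compile.compile_sql(
        compile.CompileRequest(
            plan,
            sort_rows=False,                  # caller re-sorts later if needed
            materialize_all_order_keys=True,  # keep ordering columns in the SQL output
        )
    )
    # compiled.sql: query text; compiled.sql_schema: BigQuery schema of that query;
    # compiled.row_order: ordering metadata for re-reading the materialized table.
    return compiled.sql, compiled.sql_schema, compiled.row_order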

bigframes/core/pyarrow_utils.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import itertools
+from typing import Iterable, Iterator, Optional
+
+import pyarrow as pa
+
+
+def peek_batches(
+    batch_iter: Iterable[pa.RecordBatch], max_bytes: int
+) -> tuple[Iterator[pa.RecordBatch], Optional[tuple[pa.RecordBatch, ...]]]:
+    """
+    Try to peek a pyarrow batch iterable. If greater than max_bytes, give up.
+
+    Will consume max_bytes + one batch of memory at worst.
+    """
+    batch_list = []
+    current_bytes = 0
+    for batch in batch_iter:
+        batch_list.append(batch)
+        current_bytes += batch.nbytes
+
+        if current_bytes > max_bytes:
+            return itertools.chain(batch_list, batch_iter), None
+
+    return iter(batch_list), tuple(batch_list)
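
The helper buffers batches only while they fit under the byte budget. A small usage sketch (not part of the commit) showing both outcomes with plain pyarrow data:

import pyarrow as pa

from bigframes.core import pyarrow_utils

batches = [
    pa.RecordBatch.from_pydict({"x": list(range(1000))}),
    pa.RecordBatch.from_pydict({"x": list(range(1000))}),
]

# Budget too small: peeking gives up (None), but the returned iterator still
# yields every batch because the buffered prefix is chained back onto the stream.
it, peeked = pyarrow_utils.peek_batches(iter(batches), max_bytes=10)
assert peeked is None
assert sum(b.num_rows for b in it) == 2000

# Budget large enough: the whole stream is buffered, so the caller gets both
# the full tuple of batches and an iterator that replays them.
it, peeked = pyarrow_utils.peek_batches(iter(batches), max_bytes=1_000_000)
assert peeked is not None and len(peeked) == 2
assert sum(b.num_rows for b in it) == 2000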

bigframes/session/bq_caching_executor.py

Lines changed: 103 additions & 44 deletions
@@ -25,10 +25,10 @@
 import google.cloud.bigquery.job as bq_job
 import google.cloud.bigquery.table as bq_table
 import google.cloud.bigquery_storage_v1
+import pyarrow as pa

 import bigframes.core
-from bigframes.core import rewrite
-import bigframes.core.compile
+from bigframes.core import compile, local_data, pyarrow_utils, rewrite
 import bigframes.core.guid
 import bigframes.core.nodes as nodes
 import bigframes.core.ordering as order
@@ -70,9 +70,6 @@ def __init__(
     ):
         self.bqclient = bqclient
         self.storage_manager = storage_manager
-        self.compiler: bigframes.core.compile.SQLCompiler = (
-            bigframes.core.compile.SQLCompiler()
-        )
         self.strictly_ordered: bool = strictly_ordered
         self._cached_executions: weakref.WeakKeyDictionary[
             nodes.BigFrameNode, nodes.BigFrameNode
@@ -97,8 +94,11 @@ def to_sql(
     ) -> str:
         if offset_column:
             array_value, _ = array_value.promote_offsets()
-        node = self.logical_plan(array_value.node) if enable_cache else array_value.node
-        return self.compiler.compile(node, ordered=ordered)
+        node = (
+            self.simplify_plan(array_value.node) if enable_cache else array_value.node
+        )
+        compiled = compile.compile_sql(compile.CompileRequest(node, sort_rows=ordered))
+        return compiled.sql

     def execute(
         self,
@@ -115,7 +115,6 @@ def execute(
         if bigframes.options.compute.enable_multi_query_execution:
             self._simplify_with_caching(array_value)

-        plan = self.logical_plan(array_value.node)
         # Use explicit destination to avoid 10GB limit of temporary table
         destination_table = (
             self.storage_manager.create_temp_table(
@@ -125,7 +124,7 @@
             else None
         )
         return self._execute_plan(
-            plan,
+            array_value.node,
             ordered=ordered,
             page_size=page_size,
             max_results=max_results,
@@ -224,7 +223,7 @@ def peek(
         """
         A 'peek' efficiently accesses a small number of rows in the dataframe.
         """
-        plan = self.logical_plan(array_value.node)
+        plan = self.simplify_plan(array_value.node)
         if not tree_properties.can_fast_peek(plan):
             msg = bfe.format_message("Peeking this value cannot be done efficiently.")
             warnings.warn(msg)
@@ -240,7 +239,7 @@ def peek(
         )

         return self._execute_plan(
-            plan, ordered=False, destination=destination_table, peek=n_rows
+            array_value.node, ordered=False, destination=destination_table, peek=n_rows
         )

     def cached(
@@ -329,10 +328,10 @@ def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue):
         # Once rewriting is available, will want to rewrite before
         # evaluating execution cost.
         return tree_properties.is_trivially_executable(
-            self.logical_plan(array_value.node)
+            self.simplify_plan(array_value.node)
         )

-    def logical_plan(self, root: nodes.BigFrameNode) -> nodes.BigFrameNode:
+    def simplify_plan(self, root: nodes.BigFrameNode) -> nodes.BigFrameNode:
         """
         Apply universal logical simplifications that are helpful regardless of engine.
         """
@@ -345,29 +344,35 @@ def _cache_with_cluster_cols(
         self, array_value: bigframes.core.ArrayValue, cluster_cols: Sequence[str]
     ):
         """Executes the query and uses the resulting table to rewrite future executions."""
-
-        sql, schema, ordering_info = self.compiler.compile_raw(
-            self.logical_plan(array_value.node)
+        plan = self.simplify_plan(array_value.node)
+        compiled = compile.compile_sql(
+            compile.CompileRequest(
+                plan, sort_rows=False, materialize_all_order_keys=True
+            )
         )
         tmp_table = self._sql_as_cached_temp_table(
-            sql,
-            schema,
-            cluster_cols=bq_io.select_cluster_cols(schema, cluster_cols),
+            compiled.sql,
+            compiled.sql_schema,
+            cluster_cols=bq_io.select_cluster_cols(compiled.sql_schema, cluster_cols),
         )
         cached_replacement = array_value.as_cached(
             cache_table=self.bqclient.get_table(tmp_table),
-            ordering=ordering_info,
+            ordering=compiled.row_order,
         ).node
         self._cached_executions[array_value.node] = cached_replacement

     def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
         """Executes the query and uses the resulting table to rewrite future executions."""
         offset_column = bigframes.core.guid.generate_guid("bigframes_offsets")
         w_offsets, offset_column = array_value.promote_offsets()
-        sql = self.compiler.compile(self.logical_plan(w_offsets.node), ordered=False)
+        compiled = compile.compile_sql(
+            compile.CompileRequest(
+                array_value.node, sort_rows=False, materialize_all_order_keys=True
+            )
+        )

         tmp_table = self._sql_as_cached_temp_table(
-            sql,
+            compiled.sql,
             w_offsets.schema.to_bigquery(),
             cluster_cols=[offset_column],
         )
@@ -401,7 +406,7 @@ def _simplify_with_caching(self, array_value: bigframes.core.ArrayValue):
         # Apply existing caching first
         for _ in range(MAX_SUBTREE_FACTORINGS):
             if (
-                self.logical_plan(array_value.node).planning_complexity
+                self.simplify_plan(array_value.node).planning_complexity
                 < QUERY_COMPLEXITY_LIMIT
             ):
                 return
@@ -458,8 +463,8 @@ def _validate_result_schema(
         bq_schema: list[bigquery.SchemaField],
     ):
         actual_schema = _sanitize(tuple(bq_schema))
-        ibis_schema = bigframes.core.compile.test_only_ibis_inferred_schema(
-            self.logical_plan(array_value.node)
+        ibis_schema = compile.test_only_ibis_inferred_schema(
+            self.simplify_plan(array_value.node)
         ).to_bigquery()
         internal_schema = _sanitize(array_value.schema.to_bigquery())
         if not bigframes.features.PANDAS_VERSIONS.is_arrow_list_dtype_usable:
@@ -477,7 +482,7 @@

     def _execute_plan(
         self,
-        plan: nodes.BigFrameNode,
+        root: nodes.BigFrameNode,
         ordered: bool,
         page_size: Optional[int] = None,
         max_results: Optional[int] = None,
@@ -490,7 +495,9 @@
         # TODO: Allow page_size and max_results by rechunking/truncating results
         if (not page_size) and (not max_results) and (not destination) and (not peek):
             for semi_executor in self._semi_executors:
-                maybe_result = semi_executor.execute(plan, ordered=ordered)
+                maybe_result = semi_executor.execute(
+                    self.simplify_plan(root), ordered=ordered
+                )
                 if maybe_result:
                     return maybe_result

@@ -500,31 +507,34 @@
         # Use explicit destination to avoid 10GB limit of temporary table
         if destination is not None:
             job_config.destination = destination
-        sql = self.compiler.compile(plan, ordered=ordered, limit=peek)
+        compiled = compile.compile_sql(
+            compile.CompileRequest(
+                self.simplify_plan(root), sort_rows=ordered, peek_count=peek
+            )
+        )
         iterator, query_job = self._run_execute_query(
-            sql=sql,
+            sql=compiled.sql,
             job_config=job_config,
             page_size=page_size,
             max_results=max_results,
             query_with_job=(destination is not None),
         )

         # Though we provide the read client, iterator may or may not use it based on what is efficient for the result
-        def iterator_supplier():
-            # Workaround issue fixed by: https://github.com/googleapis/python-bigquery/pull/2154
-            if iterator._page_size is not None or iterator.max_results is not None:
-                return iterator.to_arrow_iterable(bqstorage_client=None)
-            else:
-                return iterator.to_arrow_iterable(
-                    bqstorage_client=self.bqstoragereadclient
-                )
+        # Workaround issue fixed by: https://github.com/googleapis/python-bigquery/pull/2154
+        if iterator._page_size is not None or iterator.max_results is not None:
+            batch_iterator = iterator.to_arrow_iterable(bqstorage_client=None)
+        else:
+            batch_iterator = iterator.to_arrow_iterable(
+                bqstorage_client=self.bqstoragereadclient
+            )

         if query_job:
-            size_bytes = self.bqclient.get_table(query_job.destination).num_bytes
+            table = self.bqclient.get_table(query_job.destination)
         else:
-            size_bytes = None
+            table = None

-        if size_bytes is not None and size_bytes >= MAX_SMALL_RESULT_BYTES:
+        if (table is not None) and (table.num_bytes or 0) >= MAX_SMALL_RESULT_BYTES:
             msg = bfe.format_message(
                 "The query result size has exceeded 10 GB. In BigFrames 2.0 and "
                 "later, you might need to manually set `allow_large_results=True` in "
@@ -536,14 +546,63 @@ def iterator_supplier():
         # Do not execute these validations outside of testing suite.
         if "PYTEST_CURRENT_TEST" in os.environ:
             self._validate_result_schema(
-                bigframes.core.ArrayValue(plan), iterator.schema
+                bigframes.core.ArrayValue(root), iterator.schema
+            )
+
+        # If destination is set, this is an externally managed table, which may be mutated, so it cannot be used as a cache
+        if (
+            (destination is not None)
+            and (table is not None)
+            and (compiled.row_order is not None)
+            and (peek is None)
+        ):
+            # Assumption: GBQ cached table uses field name as bq column name
+            scan_list = nodes.ScanList(
+                tuple(
+                    nodes.ScanItem(field.id, field.dtype, field.id.name)
+                    for field in root.fields
+                )
+            )
+            cached_replacement = nodes.CachedTableNode(
+                source=nodes.BigqueryDataSource(
+                    nodes.GbqTable.from_table(
+                        table, columns=tuple(f.id.name for f in root.fields)
+                    ),
+                    ordering=compiled.row_order,
+                    n_rows=table.num_rows,
+                ),
+                scan_list=scan_list,
+                table_session=root.session,
+                original_node=root,
+            )
+            self._cached_executions[root] = cached_replacement
+        else:  # no explicit destination, can maybe peek iterator
+            # Assumption: GBQ cached table uses field name as bq column name
+            scan_list = nodes.ScanList(
+                tuple(
+                    nodes.ScanItem(field.id, field.dtype, field.id.name)
+                    for field in root.fields
+                )
+            )
+            # Will increase when we have auto-upload; 5000 bytes is the most we want to inline
+            batch_iterator, batches = pyarrow_utils.peek_batches(
+                batch_iterator, max_bytes=5000
             )
+            if batches:
+                local_cached = nodes.ReadLocalNode(
+                    local_data_source=local_data.ManagedArrowTable.from_pyarrow(
+                        pa.Table.from_batches(batches)
+                    ),
+                    scan_list=scan_list,
+                    session=root.session,
+                )
+                self._cached_executions[root] = local_cached

         return executor.ExecuteResult(
-            arrow_batches=iterator_supplier,
-            schema=plan.schema,
+            arrow_batches=batch_iterator,
+            schema=root.schema,
             query_job=query_job,
-            total_bytes=size_bytes,
+            total_bytes=table.num_bytes if table else None,
             total_rows=iterator.total_rows,
         )
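Taken together, the new tail of _execute_plan caches opportunistically: a result written to an explicit destination with a recoverable row order (and not a peek) is registered as a CachedTableNode over the BigQuery table, while other results are peeked and, if they fit in at most 5000 bytes of Arrow data, inlined as a ReadLocalNode. A schematic restatement of that branch, using plain-value stand-ins rather than the executor's real objects:

from typing import Optional


def choose_cache_strategy(
    has_destination: bool,
    table_known: bool,
    row_order_known: bool,
    is_peek: bool,
    peeked_nbytes: Optional[int],
    inline_limit: int = 5000,
) -> str:
    """Schematic mirror of the caching branch added to _execute_plan."""
    if has_destination and table_known and row_order_known and not is_peek:
        # Materialized table with known ordering: later executions can rewrite
        # the plan to scan the BigQuery table instead of recomputing it.
        return "CachedTableNode"
    if peeked_nbytes is not None and peeked_nbytes <= inline_limit:
        # Tiny result fully buffered by peek_batches: cache it as local data.
        return "ReadLocalNode"
    # Otherwise the batches are streamed through without registering a cache entry.
    return "no cache"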
bigframes/session/executor.py

Lines changed: 4 additions & 4 deletions
@@ -18,7 +18,7 @@
 import dataclasses
 import functools
 import itertools
-from typing import Callable, Iterator, Literal, Mapping, Optional, Sequence, Union
+from typing import Iterator, Literal, Mapping, Optional, Sequence, Union

 from google.cloud import bigquery
 import pandas as pd
@@ -31,7 +31,7 @@

 @dataclasses.dataclass(frozen=True)
 class ExecuteResult:
-    arrow_batches: Callable[[], Iterator[pyarrow.RecordBatch]]
+    arrow_batches: Iterator[pyarrow.RecordBatch]
     schema: bigframes.core.schema.ArraySchema
     query_job: Optional[bigquery.QueryJob] = None
     total_bytes: Optional[int] = None
@@ -41,7 +41,7 @@ def to_arrow_table(self) -> pyarrow.Table:
         # Need to provide schema if no result rows, as arrow can't infer
         # If there are rows, it is safest to infer schema from batches.
         # Any discrepancies between predicted schema and actual schema will produce errors.
-        batches = iter(self.arrow_batches())
+        batches = iter(self.arrow_batches)
         peek_it = itertools.islice(batches, 0, 1)
         peek_value = list(peek_it)
         # TODO: Enforce our internal schema on the table for consistency
@@ -58,7 +58,7 @@ def to_pandas(self) -> pd.DataFrame:
     def to_pandas_batches(self) -> Iterator[pd.DataFrame]:
         yield from map(
             functools.partial(io_pandas.arrow_to_pandas, schema=self.schema),
-            self.arrow_batches(),
+            self.arrow_batches,
         )

     def to_py_scalar():
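
Because arrow_batches is now the iterator itself rather than a factory, a result's batches stream through exactly once, which is what lets _execute_plan peek the stream and chain the buffered prefix back on. A toy illustration of the consequence (generic Python, not the bigframes API):

from typing import Iterator

import pyarrow as pa


def make_batches() -> Iterator[pa.RecordBatch]:
    yield pa.RecordBatch.from_pydict({"x": [1, 2, 3]})


# Old shape: a supplier could be called repeatedly, producing a fresh iterator each time.
supplier = make_batches
assert sum(b.num_rows for b in supplier()) == 3
assert sum(b.num_rows for b in supplier()) == 3

# New shape: a single iterator is exhausted after one pass.
batches = make_batches()
first_pass = sum(b.num_rows for b in batches)
second_pass = sum(b.num_rows for b in batches)
assert (first_pass, second_pass) == (3, 0)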
