Commit 9d39b43

perf: Defer some data uploads to execution time
1 parent c46ad06 commit 9d39b43

File tree: 5 files changed (+123, −58 lines)

- bigframes/core/array_value.py
- bigframes/core/nodes.py
- bigframes/session/__init__.py
- bigframes/session/bq_caching_executor.py
- bigframes/session/loader.py
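In plain terms: local data that remains inline in the query plan (a ReadLocalNode) is no longer uploaded up front; the executor writes it to a temporary table only when a query actually executes, and only if it exceeds the new MAX_INLINE_BYTES threshold (5000 bytes). A rough user-level illustration follows; which entry points keep data inline, and exactly when the upload fires, is decided outside this diff, so the trigger points shown in the comments are assumptions:

    import pandas as pd
    import bigframes.pandas as bpd

    df = bpd.read_pandas(pd.DataFrame({"x": range(100_000)}))  # may stay local/inline, no upload yet
    filtered = df[df["x"] > 50_000]                            # still deferred, nothing sent to BigQuery
    result = filtered.to_pandas()                              # upload (if needed) and query happen here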

bigframes/core/array_value.py

Lines changed: 10 additions & 1 deletion

@@ -133,8 +133,17 @@ def from_table(
             ordering=ordering,
             n_rows=n_rows,
         )
+        return cls.from_bq_data_source(source_def, scan_list, session)
+
+    @classmethod
+    def from_bq_data_source(
+        cls,
+        source: nodes.BigqueryDataSource,
+        scan_list: nodes.ScanList,
+        session: Session,
+    ):
         node = nodes.ReadTableNode(
-            source=source_def,
+            source=source,
             scan_list=scan_list,
             table_session=session,
         )
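The new from_bq_data_source classmethod lets callers build an ArrayValue around a BigQuery source that has already been materialized, instead of going back through from_table. A minimal sketch of the call pattern, mirroring how the loader change further down uses it (the variables `gbq_source`, `managed_data`, and `session` are assumed to exist in scope):

    from bigframes.core import identifiers, nodes
    import bigframes.core as core

    # `gbq_source` is a nodes.BigqueryDataSource for an already-uploaded table,
    # `managed_data` the local data it came from, and `session` the owning Session.
    scan_list = nodes.ScanList(
        tuple(
            nodes.ScanItem(identifiers.ColumnId(item.column), item.dtype, item.column)
            for item in managed_data.schema.items
        )
    )
    array_value = core.ArrayValue.from_bq_data_source(
        source=gbq_source,
        scan_list=scan_list,
        session=session,
    )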

bigframes/core/nodes.py

Lines changed: 5 additions & 0 deletions

@@ -614,6 +614,11 @@ def project(
         result = ScanList((self.items[:1]))
         return result

+    def append(
+        self, source_id: str, dtype: bigframes.dtypes.Dtype, id: identifiers.ColumnId
+    ) -> ScanList:
+        return ScanList((*self.items, ScanItem(id, dtype, source_id)))
+

 @dataclasses.dataclass(frozen=True, eq=False)
 class ReadLocalNode(LeafNode):
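ScanList.append returns a new ScanList rather than mutating in place, which keeps the node dataclasses immutable. The executor change below uses it to tack the generated offsets column onto an existing scan list, roughly:

    # `scan_list`, `offsets_col`, and `node` as in _cache_local_table below.
    scan_list = scan_list.append(offsets_col, bigframes.dtypes.INT_DTYPE, node.offsets_col)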

bigframes/session/__init__.py

Lines changed: 8 additions & 7 deletions

@@ -244,13 +244,6 @@ def __init__(
         self._temp_storage_manager = (
             self._session_resource_manager or self._anon_dataset_manager
         )
-        self._executor: executor.Executor = bq_caching_executor.BigQueryCachingExecutor(
-            bqclient=self._clients_provider.bqclient,
-            bqstoragereadclient=self._clients_provider.bqstoragereadclient,
-            storage_manager=self._temp_storage_manager,
-            strictly_ordered=self._strictly_ordered,
-            metrics=self._metrics,
-        )
         self._loader = bigframes.session.loader.GbqDataLoader(
             session=self,
             bqclient=self._clients_provider.bqclient,
@@ -261,6 +254,14 @@ def __init__(
             force_total_order=self._strictly_ordered,
             metrics=self._metrics,
         )
+        self._executor: executor.Executor = bq_caching_executor.BigQueryCachingExecutor(
+            bqclient=self._clients_provider.bqclient,
+            bqstoragereadclient=self._clients_provider.bqstoragereadclient,
+            storage_manager=self._temp_storage_manager,
+            loader=self._loader,
+            strictly_ordered=self._strictly_ordered,
+            metrics=self._metrics,
+        )

     def __del__(self):
         """Automatic cleanup of internal resources."""

bigframes/session/bq_caching_executor.py

Lines changed: 29 additions & 1 deletion

@@ -35,7 +35,7 @@
 import bigframes.dtypes
 import bigframes.exceptions as bfe
 import bigframes.features
-from bigframes.session import executor, local_scan_executor, read_api_execution
+from bigframes.session import executor, loader, local_scan_executor, read_api_execution
 import bigframes.session._io.bigquery as bq_io
 import bigframes.session.metrics
 import bigframes.session.planner
@@ -47,6 +47,7 @@
 MAX_SUBTREE_FACTORINGS = 5
 _MAX_CLUSTER_COLUMNS = 4
 MAX_SMALL_RESULT_BYTES = 10 * 1024 * 1024 * 1024  # 10G
+MAX_INLINE_BYTES = 5000


 class BigQueryCachingExecutor(executor.Executor):
@@ -63,6 +64,7 @@ def __init__(
         bqclient: bigquery.Client,
         storage_manager: bigframes.session.temporary_storage.TemporaryStorageManager,
         bqstoragereadclient: google.cloud.bigquery_storage_v1.BigQueryReadClient,
+        loader: loader.GbqDataLoader,
         *,
         strictly_ordered: bool = True,
         metrics: Optional[bigframes.session.metrics.ExecutionMetrics] = None,
@@ -72,6 +74,7 @@ def __init__(
         self.compiler: bigframes.core.compile.SQLCompiler = (
             bigframes.core.compile.SQLCompiler()
         )
+        self.loader = loader
         self.strictly_ordered: bool = strictly_ordered
         self._cached_executions: weakref.WeakKeyDictionary[
             nodes.BigFrameNode, nodes.BigFrameNode
@@ -437,6 +440,31 @@ def _simplify_with_caching(self, array_value: bigframes.core.ArrayValue):
         if not did_cache:
             return

+    def _upload_large_local_sources(self, root: nodes.BigFrameNode):
+        for leaf in root.unique_nodes():
+            if isinstance(leaf, nodes.ReadLocalNode):
+                if leaf.local_data_source.metadata.total_bytes > MAX_INLINE_BYTES:
+                    self._cache_local_table(leaf)
+
+    def _cache_local_table(self, node: nodes.ReadLocalNode):
+        offsets_col = bigframes.core.guid.generate_guid()
+        # TODO: Best effort go through available upload paths
+        bq_data_source = self.loader.write_data(
+            node.local_data_source, offsets_col=offsets_col
+        )
+        scan_list = node.scan_list
+        if node.offsets_col is not None:
+            scan_list = scan_list.append(
+                offsets_col, bigframes.dtypes.INT_DTYPE, node.offsets_col
+            )
+        cache_node = nodes.CachedTableNode(
+            source=bq_data_source,
+            scan_list=scan_list,
+            table_session=self.loader._session,
+            original_node=node,
+        )
+        self._cached_executions[node] = cache_node
+
     def _cache_most_complex_subtree(self, node: nodes.BigFrameNode) -> bool:
         # TODO: If query fails, retry with lower complexity limit
         selection = tree_properties.select_cache_target(
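These two helpers are the core of the deferral: nothing is uploaded when the plan is built; when a tree is about to execute, any local leaf larger than MAX_INLINE_BYTES is written to a temp table and substituted through the existing _cached_executions mechanism. The call site is not shown in this diff, so the surrounding flow below is a hypothetical sketch only; _upload_large_local_sources, _cached_executions, MAX_INLINE_BYTES, and self.compiler come from this change, the other names are assumptions for illustration:

    # Hypothetical sketch of the executor-side flow at execution time.
    class _ExecutionFlowSketch:
        def _execute_plan(self, plan):
            # 1. Upload any local leaves larger than MAX_INLINE_BYTES; each uploaded
            #    leaf gets a CachedTableNode recorded in self._cached_executions.
            self._upload_large_local_sources(plan)
            # 2. Substitute cached/uploaded nodes into the tree before compiling.
            plan = self._substitute_cached_subtrees(plan)  # assumed helper
            # 3. Compile to SQL and run as before.
            sql = self.compiler.compile(plan)
            return self._run_query(sql)  # assumed helper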

bigframes/session/loader.py

Lines changed: 71 additions & 49 deletions

@@ -43,7 +43,7 @@
 import pandas
 import pyarrow as pa

-from bigframes.core import guid, local_data, utils
+from bigframes.core import guid, identifiers, local_data, nodes, ordering, utils
 import bigframes.core as core
 import bigframes.core.blocks as blocks
 import bigframes.core.schema as schemata
@@ -183,35 +183,59 @@ def read_pandas(
         )
         managed_data = local_data.ManagedArrowTable.from_pandas(prepared_df)

+        block = blocks.Block(
+            self.read_managed_data(managed_data, method=method, api_name=api_name),
+            index_columns=idx_cols,
+            column_labels=pandas_dataframe.columns,
+            index_labels=pandas_dataframe.index.names,
+        )
+        return dataframe.DataFrame(block)
+
+    def read_managed_data(
+        self,
+        data: local_data.ManagedArrowTable,
+        method: Literal["load", "stream", "write"],
+        api_name: str,
+    ) -> core.ArrayValue:
+        offsets_col = guid.generate_guid("upload_offsets_")
         if method == "load":
-            array_value = self.load_data(managed_data, api_name=api_name)
+            gbq_source = self.load_data(
+                data, offsets_col=offsets_col, api_name=api_name
+            )
         elif method == "stream":
-            array_value = self.stream_data(managed_data)
+            gbq_source = self.stream_data(data, offsets_col=offsets_col)
         elif method == "write":
-            array_value = self.write_data(managed_data)
+            gbq_source = self.write_data(data, offsets_col=offsets_col)
         else:
             raise ValueError(f"Unsupported read method {method}")

-        block = blocks.Block(
-            array_value,
-            index_columns=idx_cols,
-            column_labels=pandas_dataframe.columns,
-            index_labels=pandas_dataframe.index.names,
+        return core.ArrayValue.from_bq_data_source(
+            source=gbq_source,
+            scan_list=nodes.ScanList(
+                tuple(
+                    nodes.ScanItem(
+                        identifiers.ColumnId(item.column), item.dtype, item.column
+                    )
+                    for item in data.schema.items
+                )
+            ),
+            session=self._session,
         )
-        return dataframe.DataFrame(block)

     def load_data(
-        self, data: local_data.ManagedArrowTable, api_name: Optional[str] = None
-    ) -> core.ArrayValue:
+        self,
+        data: local_data.ManagedArrowTable,
+        offsets_col: str,
+        api_name: Optional[str] = None,
+    ) -> nodes.BigqueryDataSource:
         """Load managed data into bigquery"""
-        ordering_col = guid.generate_guid("load_offsets_")

         # JSON support incomplete
         for item in data.schema.items:
             _validate_dtype_can_load(item.column, item.dtype)

         schema_w_offsets = data.schema.append(
-            schemata.SchemaItem(ordering_col, bigframes.dtypes.INT_DTYPE)
+            schemata.SchemaItem(offsets_col, bigframes.dtypes.INT_DTYPE)
         )
         bq_schema = schema_w_offsets.to_bigquery(_LOAD_JOB_TYPE_OVERRIDES)

@@ -222,13 +246,13 @@ def load_data(
             job_config.labels = {"bigframes-api": api_name}

         load_table_destination = self._storage_manager.create_temp_table(
-            bq_schema, [ordering_col]
+            bq_schema, [offsets_col]
         )

         buffer = io.BytesIO()
         data.to_parquet(
             buffer,
-            offsets_col=ordering_col,
+            offsets_col=offsets_col,
             geo_format="wkt",
             duration_type="duration",
             json_type="string",
@@ -240,23 +264,24 @@ def load_data(
         self._start_generic_job(load_job)
         # must get table metadata after load job for accurate metadata
         destination_table = self._bqclient.get_table(load_table_destination)
-        return core.ArrayValue.from_table(
-            table=destination_table,
-            schema=schema_w_offsets,
-            session=self._session,
-            offsets_col=ordering_col,
-            n_rows=data.data.num_rows,
-        ).drop_columns([ordering_col])
+        return nodes.BigqueryDataSource(
+            nodes.GbqTable.from_table(destination_table),
+            ordering=ordering.TotalOrdering.from_offset_col(offsets_col),
+            n_rows=destination_table.num_rows,
+        )

-    def stream_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:
+    def stream_data(
+        self,
+        data: local_data.ManagedArrowTable,
+        offsets_col: str,
+    ) -> nodes.BigqueryDataSource:
         """Load managed data into bigquery"""
-        ordering_col = guid.generate_guid("stream_offsets_")
         schema_w_offsets = data.schema.append(
-            schemata.SchemaItem(ordering_col, bigframes.dtypes.INT_DTYPE)
+            schemata.SchemaItem(offsets_col, bigframes.dtypes.INT_DTYPE)
         )
         bq_schema = schema_w_offsets.to_bigquery(_STREAM_JOB_TYPE_OVERRIDES)
         load_table_destination = self._storage_manager.create_temp_table(
-            bq_schema, [ordering_col]
+            bq_schema, [offsets_col]
         )

         rows = data.itertuples(
@@ -275,24 +300,23 @@ def stream_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:
                 f"Problem loading at least one row from DataFrame: {errors}. {constants.FEEDBACK_LINK}"
             )
         destination_table = self._bqclient.get_table(load_table_destination)
-        return core.ArrayValue.from_table(
-            table=destination_table,
-            schema=schema_w_offsets,
-            session=self._session,
-            offsets_col=ordering_col,
-            n_rows=data.data.num_rows,
-        ).drop_columns([ordering_col])
+        return nodes.BigqueryDataSource(
+            nodes.GbqTable.from_table(destination_table),
+            ordering=ordering.TotalOrdering.from_offset_col(offsets_col),
+            n_rows=destination_table.num_rows,
+        )

-    def write_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:
+    def write_data(
+        self,
+        data: local_data.ManagedArrowTable,
+        offsets_col: str,
+    ) -> nodes.BigqueryDataSource:
         """Load managed data into bigquery"""
-        ordering_col = guid.generate_guid("stream_offsets_")
         schema_w_offsets = data.schema.append(
-            schemata.SchemaItem(ordering_col, bigframes.dtypes.INT_DTYPE)
+            schemata.SchemaItem(offsets_col, bigframes.dtypes.INT_DTYPE)
         )
         bq_schema = schema_w_offsets.to_bigquery(_STREAM_JOB_TYPE_OVERRIDES)
-        bq_table_ref = self._storage_manager.create_temp_table(
-            bq_schema, [ordering_col]
-        )
+        bq_table_ref = self._storage_manager.create_temp_table(bq_schema, [offsets_col])

         requested_stream = bq_storage_types.stream.WriteStream()
         requested_stream.type_ = bq_storage_types.stream.WriteStream.Type.COMMITTED  # type: ignore
@@ -304,7 +328,7 @@ def write_data(self, data: local_data.ManagedArrowTable) -> core.ArrayValue:

         def request_gen() -> Generator[bq_storage_types.AppendRowsRequest, None, None]:
             schema, batches = data.to_arrow(
-                offsets_col=ordering_col, duration_type="int"
+                offsets_col=offsets_col, duration_type="int"
             )
             offset = 0
             for batch in batches:
@@ -330,13 +354,11 @@ def request_gen() -> Generator[bq_storage_types.AppendRowsRequest, None, None]:
         assert response.row_count == data.data.num_rows

         destination_table = self._bqclient.get_table(bq_table_ref)
-        return core.ArrayValue.from_table(
-            table=destination_table,
-            schema=schema_w_offsets,
-            session=self._session,
-            offsets_col=ordering_col,
-            n_rows=data.data.num_rows,
-        ).drop_columns([ordering_col])
+        return nodes.BigqueryDataSource(
+            nodes.GbqTable.from_table(destination_table),
+            ordering=ordering.TotalOrdering.from_offset_col(offsets_col),
+            n_rows=destination_table.num_rows,
+        )

     def _start_generic_job(self, job: formatting_helpers.GenericJob):
         if bigframes.options.display.progress_bar is not None:
@@ -533,7 +555,7 @@ def read_gbq_table(
         if not primary_key:
             array_value = array_value.order_by(
                 [
-                    bigframes.core.ordering.OrderingExpression(
+                    ordering.OrderingExpression(
                         bigframes.operations.RowKey().as_expr(
                             *(id for id in array_value.column_ids)
                         ),
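The loader change is what makes the deferral possible: load_data, stream_data, and write_data now take the offsets column name from the caller and return a raw nodes.BigqueryDataSource instead of an ArrayValue, so the executor can call write_data directly and wrap the result itself, while read_pandas keeps its eager behavior via the new read_managed_data. A sketch of the two consumers of the reworked API; `loader` (a GbqDataLoader), `managed` (a local_data.ManagedArrowTable), and the api_name string are assumed for illustration:

    from bigframes.core import guid

    # Eager path (read_pandas): upload now and get back an ArrayValue.
    array_value = loader.read_managed_data(managed, method="load", api_name="read_pandas")

    # Deferred path (executor._cache_local_table): upload at execution time and keep
    # the raw BigQuery source so the caller can attach its own scan list and ordering.
    offsets_col = guid.generate_guid("upload_offsets_")
    bq_source = loader.write_data(managed, offsets_col=offsets_col)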
