Skip to content

Commit adc4d97

Browse files
TrevorBergeronGenesis929
authored andcommitted
perf: Add multi-query execution capability for complex dataframes (#427)
1 parent 9917e58 commit adc4d97

File tree

10 files changed

+403
-6
lines changed

10 files changed

+403
-6
lines changed

bigframes/_config/compute_options.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,11 @@ class ComputeOptions:
4040
bytes billed beyond this limit will fail (without incurring a
4141
charge). If unspecified, this will be set to your project default.
4242
See `maximum_bytes_billed <https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.QueryJobConfig#google_cloud_bigquery_job_QueryJobConfig_maximum_bytes_billed>`_.
43-
43+
enable_multi_query_execution (bool, Options):
44+
If enabled, large queries may be factored into multiple smaller queries
45+
in order to avoid generating queries that are too complex for the query
46+
engine to handle. However this comes at the cost of increase cost and latency.
4447
"""
4548

4649
maximum_bytes_billed: Optional[int] = None
50+
enable_multi_query_execution: bool = False

bigframes/core/blocks.py

+4
Original file line numberDiff line numberDiff line change
@@ -1873,6 +1873,10 @@ def cached(self, *, optimize_offsets=False, force: bool = False) -> Block:
18731873
expr = self.session._cache_with_cluster_cols(
18741874
self.expr, cluster_cols=self.index_columns
18751875
)
1876+
return self.swap_array_expr(expr)
1877+
1878+
def swap_array_expr(self, expr: core.ArrayValue) -> Block:
1879+
# TODO: Validate schema unchanged
18761880
return Block(
18771881
expr,
18781882
index_columns=self.index_columns,

bigframes/core/expression.py

+9
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
108108
def is_bijective(self) -> bool:
109109
return False
110110

111+
@property
112+
def is_identity(self) -> bool:
113+
"""True for identity operation that does not transform input."""
114+
return False
115+
111116

112117
@dataclasses.dataclass(frozen=True)
113118
class ScalarConstantExpression(Expression):
@@ -173,6 +178,10 @@ def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
173178
def is_bijective(self) -> bool:
174179
return True
175180

181+
@property
182+
def is_identity(self) -> bool:
183+
return True
184+
176185

177186
@dataclasses.dataclass(frozen=True)
178187
class OpExpression(Expression):

bigframes/core/nodes.py

+205-2
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from __future__ import annotations
1616

1717
import abc
18-
from dataclasses import dataclass, field, fields
18+
from dataclasses import dataclass, field, fields, replace
1919
import functools
2020
import itertools
2121
import typing
22-
from typing import Tuple
22+
from typing import Callable, Tuple
2323

2424
import pandas
2525

@@ -39,6 +39,10 @@
3939
import bigframes.session
4040

4141

42+
# A fixed number of variable to assume for overhead on some operations
43+
OVERHEAD_VARIABLES = 5
44+
45+
4246
@dataclass(frozen=True)
4347
class BigFrameNode:
4448
"""
@@ -102,6 +106,60 @@ def roots(self) -> typing.Set[BigFrameNode]:
102106
def schema(self) -> schemata.ArraySchema:
103107
...
104108

109+
@property
110+
@abc.abstractmethod
111+
def variables_introduced(self) -> int:
112+
"""
113+
Defines number of values created by the current node. Helps represent the "width" of a query
114+
"""
115+
...
116+
117+
@property
118+
def relation_ops_created(self) -> int:
119+
"""
120+
Defines the number of relational ops generated by the current node. Used to estimate query planning complexity.
121+
"""
122+
return 1
123+
124+
@property
125+
def joins(self) -> bool:
126+
"""
127+
Defines whether the node joins data.
128+
"""
129+
return False
130+
131+
@functools.cached_property
132+
def total_variables(self) -> int:
133+
return self.variables_introduced + sum(
134+
map(lambda x: x.total_variables, self.child_nodes)
135+
)
136+
137+
@functools.cached_property
138+
def total_relational_ops(self) -> int:
139+
return self.relation_ops_created + sum(
140+
map(lambda x: x.total_relational_ops, self.child_nodes)
141+
)
142+
143+
@functools.cached_property
144+
def total_joins(self) -> int:
145+
return int(self.joins) + sum(map(lambda x: x.total_joins, self.child_nodes))
146+
147+
@property
148+
def planning_complexity(self) -> int:
149+
"""
150+
Empirical heuristic measure of planning complexity.
151+
152+
Used to determine when to decompose overly complex computations. May require tuning.
153+
"""
154+
return self.total_variables * self.total_relational_ops * (1 + self.total_joins)
155+
156+
@abc.abstractmethod
157+
def transform_children(
158+
self, t: Callable[[BigFrameNode], BigFrameNode]
159+
) -> BigFrameNode:
160+
"""Apply a function to each child node."""
161+
...
162+
105163

106164
@dataclass(frozen=True)
107165
class UnaryNode(BigFrameNode):
@@ -115,6 +173,11 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]:
115173
def schema(self) -> schemata.ArraySchema:
116174
return self.child.schema
117175

176+
def transform_children(
177+
self, t: Callable[[BigFrameNode], BigFrameNode]
178+
) -> BigFrameNode:
179+
return replace(self, child=t(self.child))
180+
118181

119182
@dataclass(frozen=True)
120183
class JoinNode(BigFrameNode):
@@ -154,6 +217,22 @@ def join_mapping_to_schema_item(mapping: JoinColumnMapping):
154217
)
155218
return schemata.ArraySchema(items)
156219

220+
@functools.cached_property
221+
def variables_introduced(self) -> int:
222+
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
223+
return OVERHEAD_VARIABLES
224+
225+
@property
226+
def joins(self) -> bool:
227+
return True
228+
229+
def transform_children(
230+
self, t: Callable[[BigFrameNode], BigFrameNode]
231+
) -> BigFrameNode:
232+
return replace(
233+
self, left_child=t(self.left_child), right_child=t(self.right_child)
234+
)
235+
157236

158237
@dataclass(frozen=True)
159238
class ConcatNode(BigFrameNode):
@@ -182,6 +261,16 @@ def schema(self) -> schemata.ArraySchema:
182261
)
183262
return schemata.ArraySchema(items)
184263

264+
@functools.cached_property
265+
def variables_introduced(self) -> int:
266+
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
267+
return len(self.schema.items) + OVERHEAD_VARIABLES
268+
269+
def transform_children(
270+
self, t: Callable[[BigFrameNode], BigFrameNode]
271+
) -> BigFrameNode:
272+
return replace(self, children=tuple(t(child) for child in self.children))
273+
185274

186275
# Input Nodex
187276
@dataclass(frozen=True)
@@ -201,6 +290,16 @@ def roots(self) -> typing.Set[BigFrameNode]:
201290
def schema(self) -> schemata.ArraySchema:
202291
return self.data_schema
203292

293+
@functools.cached_property
294+
def variables_introduced(self) -> int:
295+
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
296+
return len(self.schema.items) + 1
297+
298+
def transform_children(
299+
self, t: Callable[[BigFrameNode], BigFrameNode]
300+
) -> BigFrameNode:
301+
return self
302+
204303

205304
# TODO: Refactor to take raw gbq object reference
206305
@dataclass(frozen=True)
@@ -233,6 +332,20 @@ def schema(self) -> schemata.ArraySchema:
233332
)
234333
return schemata.ArraySchema(items)
235334

335+
@functools.cached_property
336+
def variables_introduced(self) -> int:
337+
return len(self.columns) + len(self.hidden_ordering_columns)
338+
339+
@property
340+
def relation_ops_created(self) -> int:
341+
# Assume worst case, where readgbq actually has baked in analytic operation to generate index
342+
return 2
343+
344+
def transform_children(
345+
self, t: Callable[[BigFrameNode], BigFrameNode]
346+
) -> BigFrameNode:
347+
return self
348+
236349

237350
# Unary nodes
238351
@dataclass(frozen=True)
@@ -252,6 +365,14 @@ def schema(self) -> schemata.ArraySchema:
252365
schemata.SchemaItem(self.col_id, bigframes.dtypes.INT_DTYPE)
253366
)
254367

368+
@property
369+
def relation_ops_created(self) -> int:
370+
return 2
371+
372+
@functools.cached_property
373+
def variables_introduced(self) -> int:
374+
return 1
375+
255376

256377
@dataclass(frozen=True)
257378
class FilterNode(UnaryNode):
@@ -264,6 +385,10 @@ def row_preserving(self) -> bool:
264385
def __hash__(self):
265386
return self._node_hash
266387

388+
@property
389+
def variables_introduced(self) -> int:
390+
return 1
391+
267392

268393
@dataclass(frozen=True)
269394
class OrderByNode(UnaryNode):
@@ -281,6 +406,15 @@ def __post_init__(self):
281406
def __hash__(self):
282407
return self._node_hash
283408

409+
@property
410+
def variables_introduced(self) -> int:
411+
return 0
412+
413+
@property
414+
def relation_ops_created(self) -> int:
415+
# Doesnt directly create any relational operations
416+
return 0
417+
284418

285419
@dataclass(frozen=True)
286420
class ReversedNode(UnaryNode):
@@ -290,6 +424,15 @@ class ReversedNode(UnaryNode):
290424
def __hash__(self):
291425
return self._node_hash
292426

427+
@property
428+
def variables_introduced(self) -> int:
429+
return 0
430+
431+
@property
432+
def relation_ops_created(self) -> int:
433+
# Doesnt directly create any relational operations
434+
return 0
435+
293436

294437
@dataclass(frozen=True)
295438
class ProjectionNode(UnaryNode):
@@ -315,6 +458,12 @@ def schema(self) -> schemata.ArraySchema:
315458
)
316459
return schemata.ArraySchema(items)
317460

461+
@property
462+
def variables_introduced(self) -> int:
463+
# ignore passthrough expressions
464+
new_vars = sum(1 for i in self.assignments if not i[0].is_identity)
465+
return new_vars
466+
318467

319468
# TODO: Merge RowCount into Aggregate Node?
320469
# Row count can be compute from table metadata sometimes, so it is a bit special.
@@ -334,6 +483,10 @@ def schema(self) -> schemata.ArraySchema:
334483
(schemata.SchemaItem("count", bigframes.dtypes.INT_DTYPE),)
335484
)
336485

486+
@property
487+
def variables_introduced(self) -> int:
488+
return 1
489+
337490

338491
@dataclass(frozen=True)
339492
class AggregateNode(UnaryNode):
@@ -367,6 +520,10 @@ def schema(self) -> schemata.ArraySchema:
367520
)
368521
return schemata.ArraySchema(tuple([*by_items, *agg_items]))
369522

523+
@property
524+
def variables_introduced(self) -> int:
525+
return len(self.aggregations) + len(self.by_column_ids)
526+
370527

371528
@dataclass(frozen=True)
372529
class WindowOpNode(UnaryNode):
@@ -396,12 +553,31 @@ def schema(self) -> schemata.ArraySchema:
396553
schemata.SchemaItem(self.output_name, new_item_dtype)
397554
)
398555

556+
@property
557+
def variables_introduced(self) -> int:
558+
return 1
559+
560+
@property
561+
def relation_ops_created(self) -> int:
562+
# Assume that if not reprojecting, that there is a sequence of window operations sharing the same window
563+
return 0 if self.skip_reproject_unsafe else 4
564+
399565

566+
# TODO: Remove this op
400567
@dataclass(frozen=True)
401568
class ReprojectOpNode(UnaryNode):
402569
def __hash__(self):
403570
return self._node_hash
404571

572+
@property
573+
def variables_introduced(self) -> int:
574+
return 0
575+
576+
@property
577+
def relation_ops_created(self) -> int:
578+
# This op is not a real transformation, just a hint to the sql generator
579+
return 0
580+
405581

406582
@dataclass(frozen=True)
407583
class UnpivotNode(UnaryNode):
@@ -428,6 +604,10 @@ def row_preserving(self) -> bool:
428604
def non_local(self) -> bool:
429605
return True
430606

607+
@property
608+
def joins(self) -> bool:
609+
return True
610+
431611
@functools.cached_property
432612
def schema(self) -> schemata.ArraySchema:
433613
def infer_dtype(
@@ -469,6 +649,17 @@ def infer_dtype(
469649
]
470650
return schemata.ArraySchema((*index_items, *value_items, *passthrough_items))
471651

652+
@property
653+
def variables_introduced(self) -> int:
654+
return (
655+
len(self.schema.items) - len(self.passthrough_columns) + OVERHEAD_VARIABLES
656+
)
657+
658+
@property
659+
def relation_ops_created(self) -> int:
660+
# Unpivot is essentially a cross join and a projection.
661+
return 2
662+
472663

473664
@dataclass(frozen=True)
474665
class RandomSampleNode(UnaryNode):
@@ -485,6 +676,10 @@ def row_preserving(self) -> bool:
485676
def __hash__(self):
486677
return self._node_hash
487678

679+
@property
680+
def variables_introduced(self) -> int:
681+
return 1
682+
488683

489684
@dataclass(frozen=True)
490685
class ExplodeNode(UnaryNode):
@@ -511,3 +706,11 @@ def schema(self) -> schemata.ArraySchema:
511706
for name in self.child.schema.names
512707
)
513708
return schemata.ArraySchema(items)
709+
710+
@property
711+
def relation_ops_created(self) -> int:
712+
return 3
713+
714+
@functools.cached_property
715+
def variables_introduced(self) -> int:
716+
return len(self.column_ids) + 1

0 commit comments

Comments
 (0)