Commit fb62d40
perf: multi-query execution of complex dataframes
1 parent 815f578 commit fb62d40

File tree (4 files changed: +193 -2 lines):
bigframes/core/nodes.py
bigframes/core/traversal.py
bigframes/session/__init__.py
tests/system/small/test_dataframe.py

bigframes/core/nodes.py

Lines changed: 94 additions & 2 deletions
@@ -14,11 +14,11 @@
 
 from __future__ import annotations
 
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass, field, fields, replace
 
 import functools
 import itertools
 import typing
-from typing import Tuple
+from typing import Callable, Tuple
 
 import pandas
 
@@ -100,6 +100,16 @@ def roots(self) -> typing.Set[BigFrameNode]:
         )
         return set(roots)
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        """A crude measure of the query complexity. Not necessarily predictive of the complexity of the actual computation."""
+        return 0
+
+    def transform_children(
+        self, t: Callable[[BigFrameNode], BigFrameNode]
+    ) -> BigFrameNode:
+        return self
+
 
 @dataclass(frozen=True)
 class UnaryNode(BigFrameNode):
@@ -109,6 +119,11 @@ class UnaryNode(BigFrameNode):
     def child_nodes(self) -> typing.Sequence[BigFrameNode]:
         return (self.child,)
 
+    def transform_children(
+        self, t: Callable[[BigFrameNode], BigFrameNode]
+    ) -> BigFrameNode:
+        return replace(self, child=t(self.child))
+
 
 @dataclass(frozen=True)
 class JoinNode(BigFrameNode):
@@ -138,6 +153,18 @@ def peekable(self) -> bool:
         single_root = len(self.roots) == 1
         return children_peekable and single_root
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        child_complexity = sum(child.complexity for child in self.child_nodes)
+        return child_complexity * 2
+
+    def transform_children(
+        self, t: Callable[[BigFrameNode], BigFrameNode]
+    ) -> BigFrameNode:
+        return replace(
+            self, left_child=t(self.left_child), right_child=t(self.right_child)
+        )
+
 
 @dataclass(frozen=True)
 class ConcatNode(BigFrameNode):
@@ -150,6 +177,16 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]:
     def __hash__(self):
         return self._node_hash
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        child_complexity = sum(child.complexity for child in self.child_nodes)
+        return child_complexity * 2
+
+    def transform_children(
+        self, t: Callable[[BigFrameNode], BigFrameNode]
+    ) -> BigFrameNode:
+        return replace(self, children=tuple(t(child) for child in self.children))
+
 
 # Input Nodes
 @dataclass(frozen=True)
@@ -167,6 +204,11 @@ def peekable(self) -> bool:
     def roots(self) -> typing.Set[BigFrameNode]:
         return {self}
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        # TODO: Set to number of columns once this is more readily available
+        return 500
+
 
 # TODO: Refactor to take raw gbq object reference
 @dataclass(frozen=True)
@@ -192,6 +234,10 @@ def peekable(self) -> bool:
     def roots(self) -> typing.Set[BigFrameNode]:
         return {self}
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return len(self.columns)
+
 
 # Unary nodes
 @dataclass(frozen=True)
@@ -209,6 +255,10 @@ def peekable(self) -> bool:
     def non_local(self) -> bool:
         return False
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity * 2
+
 
 @dataclass(frozen=True)
 class FilterNode(UnaryNode):
@@ -221,6 +271,10 @@ def row_preserving(self) -> bool:
     def __hash__(self):
         return self._node_hash
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity
+
 
 @dataclass(frozen=True)
 class OrderByNode(UnaryNode):
@@ -229,6 +283,10 @@ class OrderByNode(UnaryNode):
     def __hash__(self):
         return self._node_hash
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity
+
 
 @dataclass(frozen=True)
 class ReversedNode(UnaryNode):
@@ -238,6 +296,10 @@ class ReversedNode(UnaryNode):
     def __hash__(self):
         return self._node_hash
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity
+
 
 @dataclass(frozen=True)
 class ProjectionNode(UnaryNode):
@@ -246,6 +308,10 @@
     def __hash__(self):
         return self._node_hash
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity
+
 
 # TODO: Merge RowCount into Aggregate Node?
 # Row count can sometimes be computed from table metadata, so it is a bit special.
@@ -259,6 +325,10 @@ def row_preserving(self) -> bool:
     def non_local(self) -> bool:
         return True
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity
+
 
 @dataclass(frozen=True)
 class AggregateNode(UnaryNode):
@@ -281,6 +351,10 @@ def peekable(self) -> bool:
     def non_local(self) -> bool:
         return True
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity * 2
+
 
 @dataclass(frozen=True)
 class WindowOpNode(UnaryNode):
@@ -302,12 +376,22 @@ def peekable(self) -> bool:
     def non_local(self) -> bool:
         return True
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        if self.skip_reproject_unsafe:
+            return self.child.complexity
+        return self.child.complexity * 2
+
 
 @dataclass(frozen=True)
 class ReprojectOpNode(UnaryNode):
     def __hash__(self):
         return self._node_hash
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity * 2
+
 
 @dataclass(frozen=True)
 class UnpivotNode(UnaryNode):
@@ -337,6 +421,10 @@ def non_local(self) -> bool:
     def peekable(self) -> bool:
         return False
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity * 2
+
 
 @dataclass(frozen=True)
 class RandomSampleNode(UnaryNode):
@@ -350,5 +438,9 @@ def deterministic(self) -> bool:
     def row_preserving(self) -> bool:
         return False
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity
+
     def __hash__(self):
         return self._node_hash
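Taken together, the scoring rules compose multiplicatively up the tree: leaves start at a constant (500 for in-memory data, the column count for a table scan), row-level ops (filter, order, project, sample) pass the child score through unchanged, and fan-out ops (join, concat, aggregate, window, unpivot, reproject) double it. A worked sketch of the arithmetic for a hypothetical scan -> filter -> self-join -> aggregate pipeline (plain Python mirroring the rules above, not the bigframes classes):

    scan = 20                       # ReadGbqNode: len(self.columns)
    filtered = scan                 # FilterNode: child.complexity (pass-through)
    joined = 2 * (filtered + scan)  # JoinNode: 2 * sum of child complexities -> 80
    grouped = 2 * joined            # AggregateNode: child.complexity * 2 -> 160

Because every join or concat at least doubles the score, self-referential trees grow exponentially with depth, which is exactly the shape the limits in bigframes/session/__init__.py below are meant to catch.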

bigframes/core/traversal.py

Lines changed: 3 additions & 0 deletions
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import functools
+
 import bigframes.core.nodes as nodes
 
 
+@functools.cache
 def is_trivially_executable(node: nodes.BigFrameNode) -> bool:
     if local_only(node):
         return True
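Memoizing is_trivially_executable with functools.cache is safe because BigFrameNode subclasses are frozen dataclasses: immutable, hashable, and therefore usable as cache keys, so repeated traversals of shared subtrees collapse into single lookups. A minimal sketch of the pattern (illustrative stand-in types, not the bigframes implementation):

    import functools
    from dataclasses import dataclass
    from typing import Tuple

    @dataclass(frozen=True)  # frozen => hashable, usable as a cache key
    class Node:
        children: Tuple["Node", ...] = ()

    @functools.cache
    def depth(node: Node) -> int:
        # Without the cache, a DAG with heavy subtree sharing re-walks each
        # shared subtree once per parent; with it, each node is visited once.
        return 1 + max((depth(c) for c in node.children), default=0)

    shared = Node()
    root = Node(children=(Node((shared,)), Node((shared,))))
    print(depth(root))  # 3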

bigframes/session/__init__.py

Lines changed: 88 additions & 0 deletions
@@ -70,6 +70,7 @@
 import bigframes.core.blocks as blocks
 import bigframes.core.compile
 import bigframes.core.guid as guid
+import bigframes.core.nodes as nodes
 from bigframes.core.ordering import IntegerEncoding, OrderingColumnReference
 import bigframes.core.ordering as orderings
 import bigframes.core.traversal as traversals
@@ -112,6 +113,15 @@
 # TODO(tbergeron): Convert to bytes-based limit
 MAX_INLINE_DF_SIZE = 5000
 
+# Chosen to prevent very large expression trees from being directly executed as a single query.
+# Beyond this limit, expressions should be cached or decomposed into multiple queries for execution.
+# May need tuning.
+COMPLEXITY_SOFT_LIMIT = 50000
+# Beyond the hard limit, execution is not even attempted, as it would require breaking
+# the query into too many smaller queries.
+COMPLEXITY_HARD_LIMIT = COMPLEXITY_SOFT_LIMIT * 10
+# Limits how many rounds of decomposition bigframes will attempt on complex queries
+MAX_DECOMPOSITION_STEPS = 1
+
 logger = logging.getLogger(__name__)
 
 
@@ -1610,6 +1620,82 @@ def _cache_with_offsets(self, array_value: core.ArrayValue) -> core.ArrayValue:
             ordering=orderings.ExpressionOrdering.from_offset_col("bigframes_offsets"),
         )
 
+    def _cache_subtrees(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode:
+        """If the computation is complex, cache subtrees of the computation."""
+        # Recurse until subtrees fall below the complexity limit, then cache each of those
+        if node.complexity > COMPLEXITY_SOFT_LIMIT:
+            return node.transform_children(lambda n: self._cache_subtrees(n))
+        else:
+            # TODO: Add clustering columns based on access patterns
+            return self._cache_with_cluster_cols(core.ArrayValue(node), []).node
+
+    def _cache_repeated_subtrees(
+        self, node: nodes.BigFrameNode, max_iterations: int
+    ) -> nodes.BigFrameNode:
+        """If the computation is complex, cache repeated subtrees of the computation."""
+        root = node
+        # Each iteration, identify and cache the subtree that maximizes
+        # complexity * (repetitions - 1)
+        for _ in range(max_iterations):
+            node_counts: Dict[nodes.BigFrameNode, int] = dict()
+
+            candidates = list(root.child_nodes)
+            while candidates:
+                candidate = candidates.pop(0)
+                # Don't bother with low-complexity nodes
+                if candidate.complexity > COMPLEXITY_SOFT_LIMIT / 50:
+                    if candidate.complexity < COMPLEXITY_SOFT_LIMIT:
+                        node_count = node_counts.get(candidate, 0) + 1
+                        node_counts[candidate] = node_count
+                    candidates.extend(candidate.child_nodes)
+
+            # Only consider nodes that occur at least twice
+            valid_candidates = filter(lambda x: x[1] >= 2, node_counts.items())
+            # Heuristic: complexity * (copies of subtree - 1), in other words, how much
+            # complexity is reduced by deduplicating
+            best_candidate = max(
+                valid_candidates,
+                key=lambda i: i[0].complexity * (i[1] - 1),
+                default=None,
+            )
+
+            if best_candidate is None:
+                # No good subtrees to cache, just return the original tree
+                return root
+
+            node_to_replace = best_candidate[0]
+
+            # TODO: Add clustering columns based on access patterns
+            cached_node = self._cache_with_cluster_cols(
+                core.ArrayValue(best_candidate[0]), []
+            ).node
+
+            def apply_substitution(n: nodes.BigFrameNode) -> nodes.BigFrameNode:
+                if n == node_to_replace:
+                    return cached_node
+                else:
+                    return n.transform_children(apply_substitution)
+
+            root = root.transform_children(apply_substitution)
+        return root
+
+    def _simplify_with_caching(self, array_value: core.ArrayValue) -> core.ArrayValue:
+        """Attempts to reduce complexity by caching duplicated subtrees and breaking the query into pieces."""
+        node = array_value.node
+        if node.complexity < COMPLEXITY_SOFT_LIMIT:
+            return array_value
+        node = self._cache_repeated_subtrees(node, max_iterations=10)
+        if node.complexity > COMPLEXITY_HARD_LIMIT:
+            raise ValueError("This dataframe is too complex to convert to SQL queries.")
+        # TODO: add config flag to enable/disable recursive execution
+        for _ in range(MAX_DECOMPOSITION_STEPS):
+            if node.complexity > COMPLEXITY_SOFT_LIMIT:
+                node = self._cache_subtrees(node)
+            else:
+                return core.ArrayValue(node)
+        return core.ArrayValue(node)
+
     def _is_trivially_executable(self, array_value: core.ArrayValue):
         """
         Can the block be evaluated very cheaply?
@@ -1628,6 +1714,8 @@ def _execute(
         dry_run=False,
         col_id_overrides: Mapping[str, str] = {},
     ) -> tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
+        array_value = self._simplify_with_caching(array_value)
+
         sql = self._to_sql(
             array_value, sorted=sorted, col_id_overrides=col_id_overrides
         )  # type:ignore
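_cache_repeated_subtrees amounts to a common-subexpression elimination over the node tree: count occurrences of each mid-sized subtree, cache the one whose materialization saves the most complexity (score x (occurrences - 1)), splice the cached node into every occurrence, and repeat. A self-contained sketch of the substitution step using toy node types (hypothetical names, not the bigframes API):

    from dataclasses import dataclass, replace
    from typing import Callable, Tuple

    @dataclass(frozen=True)
    class Node:
        label: str
        children: Tuple["Node", ...] = ()

        def transform_children(self, t: Callable[["Node"], "Node"]) -> "Node":
            return replace(self, children=tuple(t(c) for c in self.children))

    def substitute(root: Node, target: Node, replacement: Node) -> Node:
        # Mirrors apply_substitution: rewrite top-down, swapping every
        # occurrence of `target` for the (cached) replacement node.
        def visit(n: Node) -> Node:
            if n == target:
                return replacement
            return n.transform_children(visit)
        return visit(root)

    expensive = Node("join", (Node("scan_a"), Node("scan_b")))
    tree = Node("concat", (expensive, Node("filter", (expensive,))))
    result = substitute(tree, expensive, Node("cached_table"))
    # Both occurrences of the join subtree are now the single cached node.

Any complexity remaining afterwards is handled by _cache_subtrees, which recurses until subtrees fit under the soft limit and materializes each one, so the final query only references cached intermediate results.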

tests/system/small/test_dataframe.py

Lines changed: 8 additions & 0 deletions
@@ -27,6 +27,7 @@
 import bigframes
 import bigframes._config.display_options as display_options
 import bigframes.dataframe as dataframe
+import bigframes.pandas
 import bigframes.series as series
 from tests.system.utils import (
     assert_pandas_df_equal,
@@ -3882,6 +3883,13 @@ def test_recursion_limit(scalars_df_index):
     scalars_df_index.to_pandas()
 
 
+def test_query_complexity(scalars_df_index):
+    # Recursively union the data
+    for _ in range(3):
+        scalars_df_index = bigframes.pandas.concat(10 * [scalars_df_index]).head(5)
+    scalars_df_index.to_pandas()
+
+
 def test_to_pandas_downsampling_option_override(session):
     df = session.read_gbq("bigframes-dev.bigframes_tests_sys.batting")
     download_size = 1
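Each pass concatenates 10 copies of the frame and takes head(5), so the node tree fans out roughly 10x per iteration while the result stays at five rows; with ConcatNode doubling the summed child complexity, every pass multiplies the score by about 20, inflating it by a factor of several thousand over three passes and pushing it well past COMPLEXITY_SOFT_LIMIT. The test therefore only passes if _simplify_with_caching decomposes the plan into cached pieces rather than emitting one enormous SQL statement.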
