
Commit cac75d5

perf: multi-query execution of complex dataframes
1 parent 60594f4 commit cac75d5

File tree

4 files changed: +235 −2 lines changed

bigframes/core/nodes.py

Lines changed: 112 additions & 2 deletions
```diff
@@ -14,11 +14,12 @@
 
 from __future__ import annotations
 
-from dataclasses import dataclass, field, fields
+import abc
+from dataclasses import dataclass, field, fields, replace
 import functools
 import itertools
 import typing
-from typing import Tuple
+from typing import Callable, Tuple
 
 import pandas
 
```
```diff
@@ -100,6 +101,19 @@ def roots(self) -> typing.Set[BigFrameNode]:
         )
         return set(roots)
 
+    @functools.cached_property
+    @abc.abstractmethod
+    def complexity(self) -> int:
+        """A crude measure of the query complexity. Not necessarily predictive of the complexity of the actual computation."""
+        ...
+
+    @abc.abstractmethod
+    def transform_children(
+        self, t: Callable[[BigFrameNode], BigFrameNode]
+    ) -> BigFrameNode:
+        """Apply a function to each child node."""
+        ...
+
 
 @dataclass(frozen=True)
 class UnaryNode(BigFrameNode):
@@ -109,6 +123,15 @@ class UnaryNode(BigFrameNode):
     def child_nodes(self) -> typing.Sequence[BigFrameNode]:
         return (self.child,)
 
+    def transform_children(
+        self, t: Callable[[BigFrameNode], BigFrameNode]
+    ) -> BigFrameNode:
+        return replace(self, child=t(self.child))
+
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity + 1
+
 
 @dataclass(frozen=True)
 class JoinNode(BigFrameNode):
```
138161
single_root = len(self.roots) == 1
139162
return children_peekable and single_root
140163

164+
@functools.cached_property
165+
def complexity(self) -> int:
166+
child_complexity = sum(child.complexity for child in self.child_nodes)
167+
return child_complexity * 2
168+
169+
def transform_children(
170+
self, t: Callable[[BigFrameNode], BigFrameNode]
171+
) -> BigFrameNode:
172+
return replace(
173+
self, left_child=t(self.left_child), right_child=t(self.right_child)
174+
)
175+
141176

142177
@dataclass(frozen=True)
143178
class ConcatNode(BigFrameNode):
@@ -150,6 +185,16 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]:
150185
def __hash__(self):
151186
return self._node_hash
152187

188+
@functools.cached_property
189+
def complexity(self) -> int:
190+
child_complexity = sum(child.complexity for child in self.child_nodes)
191+
return child_complexity * 2
192+
193+
def transform_children(
194+
self, t: Callable[[BigFrameNode], BigFrameNode]
195+
) -> BigFrameNode:
196+
return replace(self, children=tuple(t(child) for child in self.children))
197+
153198

154199
# Input Nodex
155200
@dataclass(frozen=True)
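To get a feel for the doubling rule on joins: each `JoinNode` costs twice the sum of its children, so a balanced tree of self-joins grows by 4× per level. A back-of-the-envelope check, assuming the flat leaf cost of 500 assigned to local data in the hunk below and the `COMPLEXITY_SOFT_LIMIT = 10**5` defined in `session/__init__.py`:

```python
# Hypothetical arithmetic only; mirrors JoinNode.complexity above.
leaf, soft_limit = 500, 10**5

c = leaf
for depth in range(1, 6):
    c = (c + c) * 2  # join of two equal subtrees: sum of children, doubled
    print(depth, c, c > soft_limit)
# depth 4 reaches 128_000 > 100_000: four levels of self-joins over even a
# tiny local frame would already trigger multi-query decomposition.
```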
```diff
@@ -167,6 +212,16 @@ def peekable(self) -> bool:
     def roots(self) -> typing.Set[BigFrameNode]:
         return {self}
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        # TODO: Set to number of columns once this is more readily available
+        return 500
+
+    def transform_children(
+        self, t: Callable[[BigFrameNode], BigFrameNode]
+    ) -> BigFrameNode:
+        return self
+
 
 # TODO: Refactor to take raw gbq object reference
 @dataclass(frozen=True)
@@ -192,6 +247,15 @@ def peekable(self) -> bool:
     def roots(self) -> typing.Set[BigFrameNode]:
         return {self}
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return len(self.columns)
+
+    def transform_children(
+        self, t: Callable[[BigFrameNode], BigFrameNode]
+    ) -> BigFrameNode:
+        return self
+
 
 # Unary nodes
 @dataclass(frozen=True)
@@ -209,6 +273,10 @@ def peekable(self) -> bool:
     def non_local(self) -> bool:
         return False
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity * 2
+
 
 @dataclass(frozen=True)
 class FilterNode(UnaryNode):
@@ -221,6 +289,10 @@ def row_preserving(self) -> bool:
     def __hash__(self):
         return self._node_hash
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity + 1
+
 
 @dataclass(frozen=True)
 class OrderByNode(UnaryNode):
@@ -229,6 +301,10 @@ class OrderByNode(UnaryNode):
     def __hash__(self):
         return self._node_hash
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity + 1
+
 
 @dataclass(frozen=True)
 class ReversedNode(UnaryNode):
@@ -238,6 +314,10 @@ class ReversedNode(UnaryNode):
     def __hash__(self):
         return self._node_hash
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity + 1
+
 
 @dataclass(frozen=True)
 class ProjectionNode(UnaryNode):
@@ -246,6 +326,10 @@ class ProjectionNode(UnaryNode):
     def __hash__(self):
         return self._node_hash
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity + 1
+
 
 # TODO: Merge RowCount into Aggregate Node?
 # Row count can be compute from table metadata sometimes, so it is a bit special.
@@ -259,6 +343,10 @@ def row_preserving(self) -> bool:
     def non_local(self) -> bool:
         return True
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity + 1
+
 
 @dataclass(frozen=True)
 class AggregateNode(UnaryNode):
@@ -281,6 +369,10 @@ def peekable(self) -> bool:
     def non_local(self) -> bool:
         return True
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity * 2
+
 
 @dataclass(frozen=True)
 class WindowOpNode(UnaryNode):
@@ -302,12 +394,22 @@ def peekable(self) -> bool:
     def non_local(self) -> bool:
         return True
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        if self.skip_reproject_unsafe:
+            return self.child.complexity
+        return self.child.complexity * 2
+
 
 @dataclass(frozen=True)
 class ReprojectOpNode(UnaryNode):
     def __hash__(self):
         return self._node_hash
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity * 2
+
 
 @dataclass(frozen=True)
 class UnpivotNode(UnaryNode):
@@ -337,6 +439,10 @@ def non_local(self) -> bool:
     def peekable(self) -> bool:
         return False
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity * 2
+
 
 @dataclass(frozen=True)
 class RandomSampleNode(UnaryNode):
@@ -350,5 +456,9 @@ def deterministic(self) -> bool:
     def row_preserving(self) -> bool:
         return False
 
+    @functools.cached_property
+    def complexity(self) -> int:
+        return self.child.complexity + 1
+
     def __hash__(self):
         return self._node_hash
```

bigframes/core/traversal.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import functools
+
 import bigframes.core.nodes as nodes
 
 
+@functools.cache
 def is_trivially_executable(node: nodes.BigFrameNode) -> bool:
     if local_only(node):
         return True
```
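The new `@functools.cache` on `is_trivially_executable` works because `BigFrameNode` instances are hashable frozen dataclasses, so structurally equal subtrees hit the same cache entry. A minimal sketch of why that combination is safe (toy `Leaf` type, not the real API):

```python
import functools
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen + eq gives value-based __eq__ and __hash__
class Leaf:
    name: str


@functools.cache  # memoizes by argument equality; safe since nodes are immutable
def describe(node: Leaf) -> str:
    return f"leaf:{node.name}"


# Two structurally equal nodes share one cached result.
assert describe(Leaf("a")) is describe(Leaf("a"))
```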

bigframes/session/__init__.py

Lines changed: 95 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import datetime
+import functools
 import itertools
 import logging
 import os
@@ -69,6 +70,7 @@
 import bigframes.core.blocks as blocks
 import bigframes.core.compile
 import bigframes.core.guid as guid
+import bigframes.core.nodes as nodes
 from bigframes.core.ordering import IntegerEncoding, OrderingColumnReference
 import bigframes.core.ordering as orderings
 import bigframes.core.traversal as traversals
```
```diff
@@ -110,6 +112,16 @@
 # TODO(tbergeron): Convert to bytes-based limit
 MAX_INLINE_DF_SIZE = 5000
 
+# Chosen to prevent very large expression trees from being directly executed as a single query.
+# Beyond this limit, the expression should be cached or decomposed into multiple queries for execution.
+COMPLEXITY_SOFT_LIMIT = 10**5
+# Beyond the hard limit, even decomposing the query is unlikely to succeed.
+COMPLEXITY_HARD_LIMIT = 10**25
+# Number of times to factor out and cache a repeated subtree
+MAX_SUBTREE_FACTORINGS = 10
+# Limits how much bigframes will attempt to decompose complex queries into smaller queries
+MAX_DECOMPOSITION_STEPS = 5
+
 logger = logging.getLogger(__name__)
```
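For a sense of scale, assuming the flat leaf cost of 500 from `nodes.py`: a single leaf crosses the soft limit after about 8 doubling operations (joins, aggregations, window ops), and the hard limit only after about 75, so the hard limit should only trip on pathologically deep trees:

```python
import math

LEAF = 500  # assumed leaf complexity from nodes.py
for name, limit in [("soft", 10**5), ("hard", 10**25)]:
    doublings = math.ceil(math.log2(limit / LEAF))
    print(name, doublings)  # soft: 8, hard: 75 consecutive doubling ops
```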

```diff
@@ -1671,6 +1683,87 @@ def _cache_with_offsets(self, array_value: core.ArrayValue) -> core.ArrayValue:
             ordering=orderings.ExpressionOrdering.from_offset_col("bigframes_offsets"),
         )
 
+    def _cache_subtrees(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode:
+        """If the computation is complex, execute subtrees of the computation."""
+        # Recurse into subtrees above the complexity limit, then cache each
+        # subtree that falls below it.
+        if node.complexity > COMPLEXITY_SOFT_LIMIT:
+            return node.transform_children(lambda n: self._cache_subtrees(n))
+        else:
+            # TODO: Add clustering columns based on access patterns
+            return self._cache_with_cluster_cols(core.ArrayValue(node), []).node
+
```
1697+
def _cache_repeated_subtrees(
1698+
self, node: nodes.BigFrameNode, max_iterations: int
1699+
) -> nodes.BigFrameNode:
1700+
"""If the computaiton is complex, execute subtrees of the compuation."""
1701+
root = node
1702+
# Identify node that maximizes complexity * (repetitions - 1)
1703+
# Recurse into subtrees below complexity limit, then cache each of those
1704+
1705+
for _ in range(max_iterations):
1706+
node_counts: Dict[nodes.BigFrameNode, int] = dict()
1707+
1708+
candidates = list(root.child_nodes)
1709+
while candidates:
1710+
candidate = candidates.pop(0)
1711+
# Don't bother with low complexity nodes
1712+
if candidate.complexity > COMPLEXITY_SOFT_LIMIT / 50:
1713+
if candidate.complexity < COMPLEXITY_SOFT_LIMIT:
1714+
node_count = node_counts.get(candidate, 0) + 1
1715+
node_counts[candidate] = node_count
1716+
candidates.extend(candidate.child_nodes)
1717+
1718+
# Only consider nodes the occur at least twice
1719+
valid_candidates = filter(lambda x: x[1] >= 2, node_counts.items())
1720+
# Heuristic: Complexity * (copies of subtree - 1), in other words, how much is complexity
1721+
# reduced by deduplicating
1722+
best_candidate = max(
1723+
valid_candidates, key=lambda i: i[0].complexity * i[1] - 1, default=None
1724+
)
1725+
1726+
if best_candidate is None:
1727+
# No good subtrees to cache, just return original tree
1728+
return root
1729+
1730+
node_to_replace = best_candidate[0]
1731+
1732+
# TODO: Add clustering columns based on access patterns
1733+
cached_node = self._cache_with_cluster_cols(
1734+
core.ArrayValue(best_candidate[0]), []
1735+
).node
1736+
1737+
@functools.cache
1738+
def apply_substition(n: nodes.BigFrameNode) -> nodes.BigFrameNode:
1739+
if n == node_to_replace:
1740+
return cached_node
1741+
else:
1742+
return n.transform_children(apply_substition)
1743+
1744+
root = root.transform_children(apply_substition)
1745+
return root
1746+
1747+
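The scoring heuristic is easiest to see on toy numbers: caching a subtree of complexity c that occurs k times removes roughly c·(k − 1) complexity from the tree. A hedged illustration with made-up candidates instead of real nodes:

```python
# Hypothetical (name, complexity, occurrence_count) candidates.
candidates = [
    ("shared_join", 40_000, 3),
    ("small_scan", 1_000, 5),
    ("big_window", 90_000, 2),
]

# Score = complexity * (occurrences - 1): the complexity removed by
# replacing every duplicate with a scan of one cached intermediate table.
best = max(candidates, key=lambda c: c[1] * (c[2] - 1))
print(best)  # ('big_window', 90000, 2): saves 90_000, beating 80_000 and 4_000
```

Each winning subtree is materialized once via `_cache_with_cluster_cols`, after which `apply_substitution` rewrites every occurrence in the tree to read from that cached result.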
```diff
+    def _simplify_with_caching(self, array_value: core.ArrayValue) -> core.ArrayValue:
+        """Attempts to handle the complexity by caching duplicated subtrees and breaking the query into pieces."""
+        node = array_value.node
+        if node.complexity < COMPLEXITY_SOFT_LIMIT:
+            return array_value
+        node = self._cache_repeated_subtrees(
+            node, max_iterations=MAX_SUBTREE_FACTORINGS
+        )
+        if node.complexity > COMPLEXITY_HARD_LIMIT:
+            raise ValueError(
+                f"This dataframe is too complex to convert to SQL queries. Internal complexity measure: {node.complexity}"
+            )
+        # TODO: add config flag to enable/disable recursive execution
+        for _ in range(MAX_DECOMPOSITION_STEPS):
+            if node.complexity > COMPLEXITY_SOFT_LIMIT:
+                node = self._cache_subtrees(node)
+            else:
+                return core.ArrayValue(node)
+        return core.ArrayValue(node)
+
     def _is_trivially_executable(self, array_value: core.ArrayValue):
         """
         Can the block be evaluated very cheaply?
```
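Taken together, `_simplify_with_caching` is a bounded loop: dedupe shared work first, bail out past the hard limit, then materialize oversized subtrees a few times until the root fits under the soft limit. A self-contained schematic of that control flow (stand-in `Plan` type and a stubbed cache step; the real methods write intermediate tables to BigQuery):

```python
from dataclasses import dataclass, replace

SOFT, HARD, MAX_STEPS = 10**5, 10**25, 5


@dataclass(frozen=True)
class Plan:
    complexity: int


def cache_subtrees(p: Plan) -> Plan:
    # Stub: pretend materializing subtrees roughly halves the remaining tree.
    return replace(p, complexity=p.complexity // 2)


def simplify(p: Plan) -> Plan:
    if p.complexity < SOFT:
        return p  # simple enough to run as a single query
    if p.complexity > HARD:
        raise ValueError("too complex even to decompose")
    for _ in range(MAX_STEPS):  # bounded number of extra queries
        if p.complexity <= SOFT:
            break
        p = cache_subtrees(p)
    return p


print(simplify(Plan(10**6)).complexity)  # 62500: under the soft limit in 4 steps
```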
```diff
@@ -1689,6 +1782,8 @@ def _execute(
         dry_run=False,
         col_id_overrides: Mapping[str, str] = {},
     ) -> tuple[bigquery.table.RowIterator, bigquery.QueryJob]:
+        array_value = self._simplify_with_caching(array_value)
+
         sql = self._to_sql(
             array_value, sorted=sorted, col_id_overrides=col_id_overrides
         )  # type:ignore
```
