Skip to content

Commit 18e5115

Browse files
perf: multi-query execution of complex dataframes
1 parent 60594f4 commit 18e5115

File tree

7 files changed

+263
-2
lines changed

7 files changed

+263
-2
lines changed

bigframes/core/blocks.py

+4
Original file line numberDiff line numberDiff line change
@@ -1808,6 +1808,10 @@ def cached(self, *, optimize_offsets=False, force: bool = False) -> Block:
18081808
expr = self.session._cache_with_cluster_cols(
18091809
self.expr, cluster_cols=self.index_columns
18101810
)
1811+
return self.swap_array_expr(expr)
1812+
1813+
def swap_array_expr(self, expr: core.ArrayValue) -> Block:
1814+
# TODO: Validate schema unchanged
18111815
return Block(
18121816
expr,
18131817
index_columns=self.index_columns,

bigframes/core/nodes.py

+112-2
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414

1515
from __future__ import annotations
1616

17-
from dataclasses import dataclass, field, fields
17+
import abc
18+
from dataclasses import dataclass, field, fields, replace
1819
import functools
1920
import itertools
2021
import typing
21-
from typing import Tuple
22+
from typing import Callable, Tuple
2223

2324
import pandas
2425

@@ -100,6 +101,19 @@ def roots(self) -> typing.Set[BigFrameNode]:
100101
)
101102
return set(roots)
102103

104+
@functools.cached_property
105+
@abc.abstractmethod
106+
def complexity(self) -> int:
107+
"""A crude measure of the query complexity. Not necessarily predictive of the complexity of the actual computation."""
108+
...
109+
110+
@abc.abstractmethod
111+
def transform_children(
112+
self, t: Callable[[BigFrameNode], BigFrameNode]
113+
) -> BigFrameNode:
114+
"""Apply a function to each child node."""
115+
...
116+
103117

104118
@dataclass(frozen=True)
105119
class UnaryNode(BigFrameNode):
@@ -109,6 +123,15 @@ class UnaryNode(BigFrameNode):
109123
def child_nodes(self) -> typing.Sequence[BigFrameNode]:
110124
return (self.child,)
111125

126+
def transform_children(
127+
self, t: Callable[[BigFrameNode], BigFrameNode]
128+
) -> BigFrameNode:
129+
return replace(self, child=t(self.child))
130+
131+
@functools.cached_property
132+
def complexity(self) -> int:
133+
return self.child.complexity + 1
134+
112135

113136
@dataclass(frozen=True)
114137
class JoinNode(BigFrameNode):
@@ -138,6 +161,18 @@ def peekable(self) -> bool:
138161
single_root = len(self.roots) == 1
139162
return children_peekable and single_root
140163

164+
@functools.cached_property
165+
def complexity(self) -> int:
166+
child_complexity = sum(child.complexity for child in self.child_nodes)
167+
return child_complexity * 2
168+
169+
def transform_children(
170+
self, t: Callable[[BigFrameNode], BigFrameNode]
171+
) -> BigFrameNode:
172+
return replace(
173+
self, left_child=t(self.left_child), right_child=t(self.right_child)
174+
)
175+
141176

142177
@dataclass(frozen=True)
143178
class ConcatNode(BigFrameNode):
@@ -150,6 +185,16 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]:
150185
def __hash__(self):
151186
return self._node_hash
152187

188+
@functools.cached_property
189+
def complexity(self) -> int:
190+
child_complexity = sum(child.complexity for child in self.child_nodes)
191+
return child_complexity * 2
192+
193+
def transform_children(
194+
self, t: Callable[[BigFrameNode], BigFrameNode]
195+
) -> BigFrameNode:
196+
return replace(self, children=tuple(t(child) for child in self.children))
197+
153198

154199
# Input Nodes
155200
@dataclass(frozen=True)
@@ -167,6 +212,16 @@ def peekable(self) -> bool:
167212
def roots(self) -> typing.Set[BigFrameNode]:
168213
return {self}
169214

215+
@functools.cached_property
216+
def complexity(self) -> int:
217+
# TODO: Set to number of columns once this is more readily available
218+
return 500
219+
220+
def transform_children(
221+
self, t: Callable[[BigFrameNode], BigFrameNode]
222+
) -> BigFrameNode:
223+
return self
224+
170225

171226
# TODO: Refactor to take raw gbq object reference
172227
@dataclass(frozen=True)
@@ -192,6 +247,15 @@ def peekable(self) -> bool:
192247
def roots(self) -> typing.Set[BigFrameNode]:
193248
return {self}
194249

250+
@functools.cached_property
251+
def complexity(self) -> int:
252+
return len(self.columns) + 5
253+
254+
def transform_children(
255+
self, t: Callable[[BigFrameNode], BigFrameNode]
256+
) -> BigFrameNode:
257+
return self
258+
195259

196260
# Unary nodes
197261
@dataclass(frozen=True)
@@ -209,6 +273,10 @@ def peekable(self) -> bool:
209273
def non_local(self) -> bool:
210274
return False
211275

276+
@functools.cached_property
277+
def complexity(self) -> int:
278+
return self.child.complexity * 2
279+
212280

213281
@dataclass(frozen=True)
214282
class FilterNode(UnaryNode):
@@ -221,6 +289,10 @@ def row_preserving(self) -> bool:
221289
def __hash__(self):
222290
return self._node_hash
223291

292+
@functools.cached_property
293+
def complexity(self) -> int:
294+
return self.child.complexity + 1
295+
224296

225297
@dataclass(frozen=True)
226298
class OrderByNode(UnaryNode):
@@ -229,6 +301,10 @@ class OrderByNode(UnaryNode):
229301
def __hash__(self):
230302
return self._node_hash
231303

304+
@functools.cached_property
305+
def complexity(self) -> int:
306+
return self.child.complexity + 1
307+
232308

233309
@dataclass(frozen=True)
234310
class ReversedNode(UnaryNode):
@@ -238,6 +314,10 @@ class ReversedNode(UnaryNode):
238314
def __hash__(self):
239315
return self._node_hash
240316

317+
@functools.cached_property
318+
def complexity(self) -> int:
319+
return self.child.complexity + 1
320+
241321

242322
@dataclass(frozen=True)
243323
class ProjectionNode(UnaryNode):
@@ -246,6 +326,10 @@ class ProjectionNode(UnaryNode):
246326
def __hash__(self):
247327
return self._node_hash
248328

329+
@functools.cached_property
330+
def complexity(self) -> int:
331+
return self.child.complexity + 1
332+
249333

250334
# TODO: Merge RowCount into Aggregate Node?
251335
# Row count can be computed from table metadata sometimes, so it is a bit special.
@@ -259,6 +343,10 @@ def row_preserving(self) -> bool:
259343
def non_local(self) -> bool:
260344
return True
261345

346+
@functools.cached_property
347+
def complexity(self) -> int:
348+
return self.child.complexity + 1
349+
262350

263351
@dataclass(frozen=True)
264352
class AggregateNode(UnaryNode):
@@ -281,6 +369,10 @@ def peekable(self) -> bool:
281369
def non_local(self) -> bool:
282370
return True
283371

372+
@functools.cached_property
373+
def complexity(self) -> int:
374+
return self.child.complexity * 2
375+
284376

285377
@dataclass(frozen=True)
286378
class WindowOpNode(UnaryNode):
@@ -302,12 +394,22 @@ def peekable(self) -> bool:
302394
def non_local(self) -> bool:
303395
return True
304396

397+
@functools.cached_property
398+
def complexity(self) -> int:
399+
if self.skip_reproject_unsafe:
400+
return self.child.complexity
401+
return self.child.complexity * 2
402+
305403

306404
@dataclass(frozen=True)
307405
class ReprojectOpNode(UnaryNode):
308406
def __hash__(self):
309407
return self._node_hash
310408

409+
@functools.cached_property
410+
def complexity(self) -> int:
411+
return self.child.complexity * 2
412+
311413

312414
@dataclass(frozen=True)
313415
class UnpivotNode(UnaryNode):
@@ -337,6 +439,10 @@ def non_local(self) -> bool:
337439
def peekable(self) -> bool:
338440
return False
339441

442+
@functools.cached_property
443+
def complexity(self) -> int:
444+
return self.child.complexity * 2
445+
340446

341447
@dataclass(frozen=True)
342448
class RandomSampleNode(UnaryNode):
@@ -350,5 +456,9 @@ def deterministic(self) -> bool:
350456
def row_preserving(self) -> bool:
351457
return False
352458

459+
@functools.cached_property
460+
def complexity(self) -> int:
461+
return self.child.complexity + 1
462+
353463
def __hash__(self):
354464
return self._node_hash

bigframes/core/traversal.py

+3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import functools
16+
1517
import bigframes.core.nodes as nodes
1618

1719

20+
@functools.cache
1821
def is_trivially_executable(node: nodes.BigFrameNode) -> bool:
1922
if local_only(node):
2023
return True

bigframes/dataframe.py

+11
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,7 @@ def to_pandas(
10601060
downsampled rows and all columns of this DataFrame.
10611061
"""
10621062
# TODO(orrbradford): Optimize this in future. Potentially some cases where we can return the stored query job
1063+
self._optimize_query_complexity()
10631064
df, query_job = self._block.to_pandas(
10641065
max_download_size=max_download_size,
10651066
sampling_method=sampling_method,
@@ -1071,6 +1072,7 @@ def to_pandas(
10711072

10721073
def to_pandas_batches(self) -> Iterable[pandas.DataFrame]:
10731074
"""Stream DataFrame results to an iterable of pandas DataFrame"""
1075+
self._optimize_query_complexity()
10741076
return self._block.to_pandas_batches()
10751077

10761078
def _compute_dry_run(self) -> bigquery.QueryJob:
@@ -2973,6 +2975,7 @@ def _run_io_query(
29732975
"""Executes a query job presenting this dataframe and returns the destination
29742976
table."""
29752977
session = self._block.expr.session
2978+
self._optimize_query_complexity()
29762979
export_array, id_overrides = self._prepare_export(
29772980
index=index, ordering_id=ordering_id
29782981
)
@@ -3109,6 +3112,14 @@ def _cached(self, *, force: bool = False) -> DataFrame:
31093112
self._set_block(self._block.cached(force=force))
31103113
return self
31113114

3115+
def _optimize_query_complexity(self):
3116+
"""Reduce query complexity by caching repeated subtrees and recursively materializing maximum-complexity subtrees.
3117+
May generate many queries and take substantial time to execute.
3118+
"""
3119+
# TODO: Move all this to session
3120+
new_expr = self._session._simplify_with_caching(self._block.expr)
3121+
self._set_block(self._block.swap_array_expr(new_expr))
3122+
31123123
_DataFrameOrSeries = typing.TypeVar("_DataFrameOrSeries")
31133124

31143125
def dot(self, other: _DataFrameOrSeries) -> _DataFrameOrSeries:

bigframes/series.py

+10
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def __len__(self):
160160
return self.shape[0]
161161

162162
def __iter__(self) -> typing.Iterator:
163+
self._optimize_query_complexity()
163164
return itertools.chain.from_iterable(
164165
map(lambda x: x.squeeze(axis=1), self._block.to_pandas_batches())
165166
)
@@ -312,6 +313,7 @@ def to_pandas(
312313
pandas.Series: A pandas Series with all rows of this Series if the data_sampling_threshold_mb
313314
is not exceeded; otherwise, a pandas Series with downsampled rows of the DataFrame.
314315
"""
316+
self._optimize_query_complexity()
315317
df, query_job = self._block.to_pandas(
316318
max_download_size=max_download_size,
317319
sampling_method=sampling_method,
@@ -1573,6 +1575,14 @@ def _cached(self, *, force: bool = True) -> Series:
15731575
self._set_block(self._block.cached(force=force))
15741576
return self
15751577

1578+
def _optimize_query_complexity(self):
1579+
"""Reduce query complexity by caching repeated subtrees and recursively materializing maximum-complexity subtrees.
1580+
May generate many queries and take substantial time to execute.
1581+
"""
1582+
# TODO: Move all this to session
1583+
new_expr = self._block.session._simplify_with_caching(self._block.expr)
1584+
self._set_block(self._block.swap_array_expr(new_expr))
1585+
15761586

15771587
def _is_list_like(obj: typing.Any) -> typing_extensions.TypeGuard[typing.Sequence]:
15781588
return pandas.api.types.is_list_like(obj)

0 commit comments

Comments
 (0)