Skip to content

Commit 8e8031b

Browse files
refactor: define planning_complexity tree property
1 parent 56cefff commit 8e8031b

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

bigframes/core/expression.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
108108
def is_bijective(self) -> bool:
109109
return False
110110

111+
@property
112+
def is_raw_variable(self) -> bool:
113+
return False
114+
111115

112116
@dataclasses.dataclass(frozen=True)
113117
class ScalarConstantExpression(Expression):
@@ -173,6 +177,10 @@ def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
173177
def is_bijective(self) -> bool:
174178
return True
175179

180+
@property
181+
def is_raw_variable(self) -> bool:
182+
return True
183+
176184

177185
@dataclasses.dataclass(frozen=True)
178186
class OpExpression(Expression):

bigframes/core/nodes.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
import bigframes.core.ordering as orderings
3939
import bigframes.session
4040

41+
# A fixed number of variable to assume for overhead on some operations
42+
OVERHEAD_VARIABLES = 5
43+
4144

4245
@dataclass(frozen=True)
4346
class BigFrameNode:
@@ -107,6 +110,38 @@ def roots(self) -> typing.Set[BigFrameNode]:
107110
def schema(self) -> schemata.ArraySchema:
108111
...
109112

113+
@property
114+
@abc.abstractmethod
115+
def variables_introduced(self) -> int:
116+
"""
117+
Defines the number of variables generated by the current node. Used to estimate query planning complexity.
118+
"""
119+
...
120+
121+
@property
122+
def relation_ops_created(self) -> int:
123+
"""
124+
Defines the number of relational ops generated by the current node. Used to estimate query planning complexity.
125+
"""
126+
return 1
127+
128+
@functools.cached_property
129+
def total_variables(self) -> int:
130+
return self.variables_introduced + sum(
131+
map(lambda x: x.total_variables, self.child_nodes)
132+
)
133+
134+
@functools.cached_property
135+
def total_relational_ops(self) -> int:
136+
return self.relation_ops_created + sum(
137+
map(lambda x: x.total_relational_ops, self.child_nodes)
138+
)
139+
140+
@property
141+
def planning_complexity(self) -> int:
142+
"""Heuristic measure of planning complexity. Used to determine when to decompose overly complex computations."""
143+
return self.total_variables * self.total_relational_ops
144+
110145

111146
@dataclass(frozen=True)
112147
class UnaryNode(BigFrameNode):
@@ -165,6 +200,10 @@ def join_mapping_to_schema_item(mapping: JoinColumnMapping):
165200
)
166201
return schemata.ArraySchema(items)
167202

203+
@functools.cached_property
204+
def variables_introduced(self) -> int:
205+
return OVERHEAD_VARIABLES
206+
168207

169208
@dataclass(frozen=True)
170209
class ConcatNode(BigFrameNode):
@@ -193,6 +232,11 @@ def schema(self) -> schemata.ArraySchema:
193232
)
194233
return schemata.ArraySchema(items)
195234

235+
@functools.cached_property
236+
def variables_introduced(self) -> int:
237+
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
238+
return OVERHEAD_VARIABLES
239+
196240

197241
# Input Nodex
198242
@dataclass(frozen=True)
@@ -216,6 +260,11 @@ def roots(self) -> typing.Set[BigFrameNode]:
216260
def schema(self) -> schemata.ArraySchema:
217261
return self.data_schema
218262

263+
@functools.cached_property
264+
def variables_introduced(self) -> int:
265+
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
266+
return len(self.schema.items) + 1
267+
219268

220269
# TODO: Refactor to take raw gbq object reference
221270
@dataclass(frozen=True)
@@ -252,6 +301,15 @@ def schema(self) -> schemata.ArraySchema:
252301
)
253302
return schemata.ArraySchema(items)
254303

304+
@functools.cached_property
305+
def variables_introduced(self) -> int:
306+
return len(self.columns) + len(self.hidden_ordering_columns)
307+
308+
@property
309+
def relation_ops_created(self) -> int:
310+
# Assume worst case, where readgbq actually has baked in analytic operation to generate index
311+
return 2
312+
255313

256314
# Unary nodes
257315
@dataclass(frozen=True)
@@ -275,6 +333,10 @@ def schema(self) -> schemata.ArraySchema:
275333
schemata.SchemaItem(self.col_id, bigframes.dtypes.INT_DTYPE)
276334
)
277335

336+
@functools.cached_property
337+
def variables_introduced(self) -> int:
338+
return 1
339+
278340

279341
@dataclass(frozen=True)
280342
class FilterNode(UnaryNode):
@@ -287,6 +349,10 @@ def row_preserving(self) -> bool:
287349
def __hash__(self):
288350
return self._node_hash
289351

352+
@property
353+
def variables_introduced(self) -> int:
354+
return 1
355+
290356

291357
@dataclass(frozen=True)
292358
class OrderByNode(UnaryNode):
@@ -304,6 +370,15 @@ def __post_init__(self):
304370
def __hash__(self):
305371
return self._node_hash
306372

373+
@property
374+
def variables_introduced(self) -> int:
375+
return 0
376+
377+
@property
378+
def relation_ops_created(self) -> int:
379+
# Doesnt directly create any relational operations
380+
return 0
381+
307382

308383
@dataclass(frozen=True)
309384
class ReversedNode(UnaryNode):
@@ -313,6 +388,15 @@ class ReversedNode(UnaryNode):
313388
def __hash__(self):
314389
return self._node_hash
315390

391+
@property
392+
def variables_introduced(self) -> int:
393+
return 0
394+
395+
@property
396+
def relation_ops_created(self) -> int:
397+
# Doesnt directly create any relational operations
398+
return 0
399+
316400

317401
@dataclass(frozen=True)
318402
class ProjectionNode(UnaryNode):
@@ -332,6 +416,12 @@ def schema(self) -> schemata.ArraySchema:
332416
)
333417
return schemata.ArraySchema(items)
334418

419+
@property
420+
def variables_introduced(self) -> int:
421+
# ignore passthrough expressions
422+
new_vars = sum(1 for i in self.assignments if not i[0].is_raw_variable)
423+
return new_vars
424+
335425

336426
# TODO: Merge RowCount into Aggregate Node?
337427
# Row count can be compute from table metadata sometimes, so it is a bit special.
@@ -351,6 +441,11 @@ def schema(self) -> schemata.ArraySchema:
351441
(schemata.SchemaItem("count", bigframes.dtypes.INT_DTYPE),)
352442
)
353443

444+
@property
445+
def variables_introduced(self) -> int:
446+
# ignore passthrough expressions
447+
return 1
448+
354449

355450
@dataclass(frozen=True)
356451
class AggregateNode(UnaryNode):
@@ -388,6 +483,10 @@ def schema(self) -> schemata.ArraySchema:
388483
)
389484
return schemata.ArraySchema(tuple([*by_items, *agg_items]))
390485

486+
@property
487+
def variables_introduced(self) -> int:
488+
return len(self.aggregations) + len(self.by_column_ids)
489+
391490

392491
@dataclass(frozen=True)
393492
class WindowOpNode(UnaryNode):
@@ -421,12 +520,31 @@ def schema(self) -> schemata.ArraySchema:
421520
schemata.SchemaItem(self.output_name, new_item_dtype)
422521
)
423522

523+
@property
524+
def variables_introduced(self) -> int:
525+
return 1
526+
527+
@property
528+
def relation_ops_created(self) -> int:
529+
# Assume that if not reprojecting, that there is a sequence of window operations sharing the same window
530+
return 0 if self.skip_reproject_unsafe else 2
531+
424532

533+
# TODO: Remove this op
425534
@dataclass(frozen=True)
426535
class ReprojectOpNode(UnaryNode):
427536
def __hash__(self):
428537
return self._node_hash
429538

539+
@property
540+
def variables_introduced(self) -> int:
541+
return 0
542+
543+
@property
544+
def relation_ops_created(self) -> int:
545+
# This op is not a real transformation, just a hint to the sql generator
546+
return 0
547+
430548

431549
@dataclass(frozen=True)
432550
class UnpivotNode(UnaryNode):
@@ -498,6 +616,17 @@ def infer_dtype(
498616
]
499617
return schemata.ArraySchema((*index_items, *value_items, *passthrough_items))
500618

619+
@property
620+
def variables_introduced(self) -> int:
621+
return (
622+
len(self.schema.items) - len(self.passthrough_columns) + OVERHEAD_VARIABLES
623+
)
624+
625+
@property
626+
def relation_ops_created(self) -> int:
627+
# Unpivot is essentially a cross join and a projection.
628+
return 2
629+
501630

502631
@dataclass(frozen=True)
503632
class RandomSampleNode(UnaryNode):
@@ -513,3 +642,7 @@ def row_preserving(self) -> bool:
513642

514643
def __hash__(self):
515644
return self._node_hash
645+
646+
@property
647+
def variables_introduced(self) -> int:
648+
return 1

0 commit comments

Comments
 (0)