Skip to content

Commit a364d6d

Browse files
refactor: define planning_complexity tree property
1 parent 56cefff commit a364d6d

File tree

2 files changed

+132
-0
lines changed

2 files changed

+132
-0
lines changed

bigframes/core/expression.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
108108
def is_bijective(self) -> bool:
109109
return False
110110

111+
@property
112+
def is_raw_variable(self) -> bool:
113+
return False
114+
111115

112116
@dataclasses.dataclass(frozen=True)
113117
class ScalarConstantExpression(Expression):
@@ -173,6 +177,10 @@ def bind_all_variables(self, bindings: Mapping[str, Expression]) -> Expression:
173177
def is_bijective(self) -> bool:
174178
return True
175179

180+
@property
181+
def is_raw_variable(self) -> bool:
182+
return True
183+
176184

177185
@dataclasses.dataclass(frozen=True)
178186
class OpExpression(Expression):

bigframes/core/nodes.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
import bigframes.core.ordering as orderings
3939
import bigframes.session
4040

41+
# A fixed number of variable to assume for overhead on some operations
42+
OVERHEAD_VARIABLES = 5
43+
4144

4245
@dataclass(frozen=True)
4346
class BigFrameNode:
@@ -107,6 +110,38 @@ def roots(self) -> typing.Set[BigFrameNode]:
107110
def schema(self) -> schemata.ArraySchema:
108111
...
109112

113+
@property
114+
@abc.abstractmethod
115+
def variables_introduced(self) -> int:
116+
"""
117+
Defines the number of variables generated by the current node. Used to estimate query planning complexity.
118+
"""
119+
...
120+
121+
@property
122+
def relation_ops_created(self) -> int:
123+
"""
124+
Defines the number of relational ops generated by the current node. Used to estimate query planning complexity.
125+
"""
126+
return 1
127+
128+
@functools.cached_property
129+
def total_variables(self) -> int:
130+
return self.variables_introduced + sum(
131+
map(lambda x: x.variables_introduced, self.child_nodes)
132+
)
133+
134+
@functools.cached_property
135+
def total_relational_ops(self) -> int:
136+
return self.relation_ops_created + sum(
137+
map(lambda x: x.relation_ops_created, self.child_nodes)
138+
)
139+
140+
@property
141+
def planning_complexity(self) -> int:
142+
"""Heuristic measure of planning complexity. Used to determine when to decompose overly complex computations."""
143+
return self.total_variables * self.total_relational_ops
144+
110145

111146
@dataclass(frozen=True)
112147
class UnaryNode(BigFrameNode):
@@ -165,6 +200,10 @@ def join_mapping_to_schema_item(mapping: JoinColumnMapping):
165200
)
166201
return schemata.ArraySchema(items)
167202

203+
@functools.cached_property
204+
def variables_introduced(self) -> int:
205+
return OVERHEAD_VARIABLES
206+
168207

169208
@dataclass(frozen=True)
170209
class ConcatNode(BigFrameNode):
@@ -193,6 +232,11 @@ def schema(self) -> schemata.ArraySchema:
193232
)
194233
return schemata.ArraySchema(items)
195234

235+
@functools.cached_property
236+
def variables_introduced(self) -> int:
237+
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
238+
return OVERHEAD_VARIABLES
239+
196240

197241
# Input Nodex
198242
@dataclass(frozen=True)
@@ -216,6 +260,11 @@ def roots(self) -> typing.Set[BigFrameNode]:
216260
def schema(self) -> schemata.ArraySchema:
217261
return self.data_schema
218262

263+
@functools.cached_property
264+
def variables_introduced(self) -> int:
265+
"""Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
266+
return len(self.schema.items) + 1
267+
219268

220269
# TODO: Refactor to take raw gbq object reference
221270
@dataclass(frozen=True)
@@ -252,6 +301,10 @@ def schema(self) -> schemata.ArraySchema:
252301
)
253302
return schemata.ArraySchema(items)
254303

304+
@functools.cached_property
305+
def variables_introduced(self) -> int:
306+
return len(self.columns) + len(self.hidden_ordering_columns)
307+
255308

256309
# Unary nodes
257310
@dataclass(frozen=True)
@@ -275,6 +328,10 @@ def schema(self) -> schemata.ArraySchema:
275328
schemata.SchemaItem(self.col_id, bigframes.dtypes.INT_DTYPE)
276329
)
277330

331+
@functools.cached_property
332+
def variables_introduced(self) -> int:
333+
return 1
334+
278335

279336
@dataclass(frozen=True)
280337
class FilterNode(UnaryNode):
@@ -287,6 +344,10 @@ def row_preserving(self) -> bool:
287344
def __hash__(self):
288345
return self._node_hash
289346

347+
@property
348+
def variables_introduced(self) -> int:
349+
return 1
350+
290351

291352
@dataclass(frozen=True)
292353
class OrderByNode(UnaryNode):
@@ -304,6 +365,15 @@ def __post_init__(self):
304365
def __hash__(self):
305366
return self._node_hash
306367

368+
@property
369+
def variables_introduced(self) -> int:
370+
return 0
371+
372+
@property
373+
def relation_ops_created(self) -> int:
374+
# Doesnt directly create any relational operations
375+
return 0
376+
307377

308378
@dataclass(frozen=True)
309379
class ReversedNode(UnaryNode):
@@ -313,6 +383,15 @@ class ReversedNode(UnaryNode):
313383
def __hash__(self):
314384
return self._node_hash
315385

386+
@property
387+
def variables_introduced(self) -> int:
388+
return 0
389+
390+
@property
391+
def relation_ops_created(self) -> int:
392+
# Doesnt directly create any relational operations
393+
return 0
394+
316395

317396
@dataclass(frozen=True)
318397
class ProjectionNode(UnaryNode):
@@ -332,6 +411,12 @@ def schema(self) -> schemata.ArraySchema:
332411
)
333412
return schemata.ArraySchema(items)
334413

414+
@property
415+
def variables_introduced(self) -> int:
416+
# ignore passthrough expressions
417+
new_vars = sum(1 for i in self.assignments if not i[0].is_raw_variable)
418+
return new_vars
419+
335420

336421
# TODO: Merge RowCount into Aggregate Node?
337422
# Row count can be compute from table metadata sometimes, so it is a bit special.
@@ -351,6 +436,11 @@ def schema(self) -> schemata.ArraySchema:
351436
(schemata.SchemaItem("count", bigframes.dtypes.INT_DTYPE),)
352437
)
353438

439+
@property
440+
def variables_introduced(self) -> int:
441+
# ignore passthrough expressions
442+
return 1
443+
354444

355445
@dataclass(frozen=True)
356446
class AggregateNode(UnaryNode):
@@ -388,6 +478,10 @@ def schema(self) -> schemata.ArraySchema:
388478
)
389479
return schemata.ArraySchema(tuple([*by_items, *agg_items]))
390480

481+
@property
482+
def variables_introduced(self) -> int:
483+
return len(self.aggregations) + len(self.by_column_ids)
484+
391485

392486
@dataclass(frozen=True)
393487
class WindowOpNode(UnaryNode):
@@ -421,12 +515,25 @@ def schema(self) -> schemata.ArraySchema:
421515
schemata.SchemaItem(self.output_name, new_item_dtype)
422516
)
423517

518+
@property
519+
def variables_introduced(self) -> int:
520+
return 1
521+
424522

425523
@dataclass(frozen=True)
426524
class ReprojectOpNode(UnaryNode):
427525
def __hash__(self):
428526
return self._node_hash
429527

528+
@property
529+
def variables_introduced(self) -> int:
530+
return 0
531+
532+
@property
533+
def relation_ops_created(self) -> int:
534+
# This op is not a real transformation, just a hint to the sql generator
535+
return 0
536+
430537

431538
@dataclass(frozen=True)
432539
class UnpivotNode(UnaryNode):
@@ -498,6 +605,19 @@ def infer_dtype(
498605
]
499606
return schemata.ArraySchema((*index_items, *value_items, *passthrough_items))
500607

608+
@property
609+
def variables_introduced(self) -> int:
610+
return (
611+
len(self.schema.items) - len(self.passthrough_columns) + OVERHEAD_VARIABLES
612+
)
613+
614+
@property
615+
def relation_ops_created(self) -> int:
616+
"""
617+
Unpivot is essentially a cross join and a projection.
618+
"""
619+
return 2
620+
501621

502622
@dataclass(frozen=True)
503623
class RandomSampleNode(UnaryNode):
@@ -513,3 +633,7 @@ def row_preserving(self) -> bool:
513633

514634
def __hash__(self):
515635
return self._node_hash
636+
637+
@property
638+
def variables_introduced(self) -> int:
639+
return 1

0 commit comments

Comments
 (0)