38
38
import bigframes .core .ordering as orderings
39
39
import bigframes .session
40
40
41
+ # A fixed number of variables to assume as overhead for some operations
42
+ OVERHEAD_VARIABLES = 5
43
+
41
44
42
45
@dataclass (frozen = True )
43
46
class BigFrameNode :
@@ -107,6 +110,38 @@ def roots(self) -> typing.Set[BigFrameNode]:
107
110
def schema (self ) -> schemata .ArraySchema :
108
111
...
109
112
113
@property
@abc.abstractmethod
def variables_introduced(self) -> int:
    """Number of variables generated by this node.

    Abstract: concrete node types supply the value. The figure is used to
    estimate query planning complexity.
    """
    ...
120
+
121
@property
def relation_ops_created(self) -> int:
    """Number of relational ops generated by this node.

    Used to estimate query planning complexity. This default implementation
    reports a single operation.
    """
    ops_created = 1
    return ops_created
127
+
128
@functools.cached_property
def total_variables(self) -> int:
    """Total number of variables introduced by this node and all of its descendants.

    Recurses through each child's ``total_variables`` so that nodes deeper
    than the direct children are included in the count. The result is cached
    per node, so each subtree is summed only once.
    """
    return self.variables_introduced + sum(
        child.total_variables for child in self.child_nodes
    )
133
+
134
@functools.cached_property
def total_relational_ops(self) -> int:
    """Total number of relational ops created by this node and all of its descendants.

    Recurses through each child's ``total_relational_ops`` so that nodes
    deeper than the direct children are included in the count. The result is
    cached per node, so each subtree is summed only once.
    """
    return self.relation_ops_created + sum(
        child.total_relational_ops for child in self.child_nodes
    )
139
+
140
@property
def planning_complexity(self) -> int:
    """Heuristic measure of planning complexity.

    Computed as ``total_variables * total_relational_ops``. Used to determine
    when to decompose overly complex computations.
    """
    variable_count = self.total_variables
    op_count = self.total_relational_ops
    return variable_count * op_count
144
+
110
145
111
146
@dataclass (frozen = True )
112
147
class UnaryNode (BigFrameNode ):
@@ -165,6 +200,10 @@ def join_mapping_to_schema_item(mapping: JoinColumnMapping):
165
200
)
166
201
return schemata .ArraySchema (items )
167
202
203
@functools.cached_property
def variables_introduced(self) -> int:
    """Number of variables this node contributes to the planning-complexity estimate.

    The node is charged the fixed ``OVERHEAD_VARIABLES`` amount.
    """
    overhead = OVERHEAD_VARIABLES
    return overhead
206
+
168
207
169
208
@dataclass (frozen = True )
170
209
class ConcatNode (BigFrameNode ):
@@ -193,6 +232,11 @@ def schema(self) -> schemata.ArraySchema:
193
232
)
194
233
return schemata .ArraySchema (items )
195
234
235
@functools.cached_property
def variables_introduced(self) -> int:
    """Number of variables generated by this node; used to estimate query planning complexity.

    The node is charged the fixed ``OVERHEAD_VARIABLES`` amount.
    """
    overhead = OVERHEAD_VARIABLES
    return overhead
239
+
196
240
197
241
# Input Nodes
198
242
@dataclass (frozen = True )
@@ -216,6 +260,11 @@ def roots(self) -> typing.Set[BigFrameNode]:
216
260
def schema (self ) -> schemata .ArraySchema :
217
261
return self .data_schema
218
262
263
@functools.cached_property
def variables_introduced(self) -> int:
    """Number of variables generated by this node; used to estimate query planning complexity.

    Counts one variable per item in the node's schema, plus one extra.
    """
    schema_item_count = len(self.schema.items)
    return schema_item_count + 1
267
+
219
268
220
269
# TODO: Refactor to take raw gbq object reference
221
270
@dataclass (frozen = True )
@@ -252,6 +301,10 @@ def schema(self) -> schemata.ArraySchema:
252
301
)
253
302
return schemata .ArraySchema (items )
254
303
304
@functools.cached_property
def variables_introduced(self) -> int:
    """Number of variables this node contributes to the planning-complexity estimate.

    One variable per entry in ``columns`` plus one per entry in
    ``hidden_ordering_columns``.
    """
    visible_count = len(self.columns)
    hidden_count = len(self.hidden_ordering_columns)
    return visible_count + hidden_count
307
+
255
308
256
309
# Unary nodes
257
310
@dataclass (frozen = True )
@@ -275,6 +328,10 @@ def schema(self) -> schemata.ArraySchema:
275
328
schemata .SchemaItem (self .col_id , bigframes .dtypes .INT_DTYPE )
276
329
)
277
330
331
@functools.cached_property
def variables_introduced(self) -> int:
    """Number of variables this node contributes to the planning-complexity estimate: always one."""
    introduced = 1
    return introduced
334
+
278
335
279
336
@dataclass (frozen = True )
280
337
class FilterNode (UnaryNode ):
@@ -287,6 +344,10 @@ def row_preserving(self) -> bool:
287
344
def __hash__(self):
    """Return the hash value stored on the node as ``_node_hash``."""
    node_hash = self._node_hash
    return node_hash
289
346
347
@property
def variables_introduced(self) -> int:
    """Number of variables this node contributes to the planning-complexity estimate: always one."""
    introduced = 1
    return introduced
350
+
290
351
291
352
@dataclass (frozen = True )
292
353
class OrderByNode (UnaryNode ):
@@ -304,6 +365,15 @@ def __post_init__(self):
304
365
def __hash__(self):
    """Return the hash value stored on the node as ``_node_hash``."""
    node_hash = self._node_hash
    return node_hash
306
367
368
@property
def variables_introduced(self) -> int:
    """Number of variables this node contributes to the planning-complexity estimate: none."""
    introduced = 0
    return introduced
371
+
372
@property
def relation_ops_created(self) -> int:
    """Number of relational ops this node contributes to the planning-complexity estimate.

    The node doesn't directly create any relational operations, so it reports zero.
    """
    ops_created = 0
    return ops_created
376
+
307
377
308
378
@dataclass (frozen = True )
309
379
class ReversedNode (UnaryNode ):
@@ -313,6 +383,15 @@ class ReversedNode(UnaryNode):
313
383
def __hash__(self):
    """Return the hash value stored on the node as ``_node_hash``."""
    node_hash = self._node_hash
    return node_hash
315
385
386
@property
def variables_introduced(self) -> int:
    """Number of variables this node contributes to the planning-complexity estimate: none."""
    introduced = 0
    return introduced
389
+
390
@property
def relation_ops_created(self) -> int:
    """Number of relational ops this node contributes to the planning-complexity estimate.

    The node doesn't directly create any relational operations, so it reports zero.
    """
    ops_created = 0
    return ops_created
394
+
316
395
317
396
@dataclass (frozen = True )
318
397
class ProjectionNode (UnaryNode ):
@@ -332,6 +411,12 @@ def schema(self) -> schemata.ArraySchema:
332
411
)
333
412
return schemata .ArraySchema (items )
334
413
414
@property
def variables_introduced(self) -> int:
    """Number of variables this node contributes to the planning-complexity estimate.

    Only assignments whose first element is not a raw variable are counted;
    passthrough expressions are ignored.
    """
    computed = [
        assignment
        for assignment in self.assignments
        if not assignment[0].is_raw_variable
    ]
    return len(computed)
419
+
335
420
336
421
# TODO: Merge RowCount into Aggregate Node?
337
422
# Row count can be computed from table metadata sometimes, so it is a bit special.
@@ -351,6 +436,11 @@ def schema(self) -> schemata.ArraySchema:
351
436
(schemata .SchemaItem ("count" , bigframes .dtypes .INT_DTYPE ),)
352
437
)
353
438
439
@property
def variables_introduced(self) -> int:
    """Number of variables this node contributes to the planning-complexity estimate: always one."""
    introduced = 1
    return introduced
443
+
354
444
355
445
@dataclass (frozen = True )
356
446
class AggregateNode (UnaryNode ):
@@ -388,6 +478,10 @@ def schema(self) -> schemata.ArraySchema:
388
478
)
389
479
return schemata .ArraySchema (tuple ([* by_items , * agg_items ]))
390
480
481
@property
def variables_introduced(self) -> int:
    """Number of variables this node contributes to the planning-complexity estimate.

    One variable per entry in ``aggregations`` plus one per entry in
    ``by_column_ids``.
    """
    aggregation_count = len(self.aggregations)
    grouping_count = len(self.by_column_ids)
    return aggregation_count + grouping_count
484
+
391
485
392
486
@dataclass (frozen = True )
393
487
class WindowOpNode (UnaryNode ):
@@ -421,12 +515,25 @@ def schema(self) -> schemata.ArraySchema:
421
515
schemata .SchemaItem (self .output_name , new_item_dtype )
422
516
)
423
517
518
@property
def variables_introduced(self) -> int:
    """Number of variables this node contributes to the planning-complexity estimate: always one."""
    introduced = 1
    return introduced
521
+
424
522
425
523
@dataclass(frozen=True)
class ReprojectOpNode(UnaryNode):
    # This op is not a real transformation, just a hint to the sql generator,
    # so it adds nothing to the planning-complexity estimate.

    def __hash__(self):
        """Return the hash value stored on the node as ``_node_hash``."""
        node_hash = self._node_hash
        return node_hash

    @property
    def variables_introduced(self) -> int:
        """Number of variables this node contributes to the complexity estimate: none."""
        introduced = 0
        return introduced

    @property
    def relation_ops_created(self) -> int:
        """Number of relational ops this node contributes to the complexity estimate: none."""
        ops_created = 0
        return ops_created
536
+
430
537
431
538
@dataclass (frozen = True )
432
539
class UnpivotNode (UnaryNode ):
@@ -498,6 +605,19 @@ def infer_dtype(
498
605
]
499
606
return schemata .ArraySchema ((* index_items , * value_items , * passthrough_items ))
500
607
608
@property
def variables_introduced(self) -> int:
    """Number of variables this node contributes to the planning-complexity estimate.

    Takes the number of schema items, subtracts the passthrough columns, and
    adds the fixed ``OVERHEAD_VARIABLES`` amount.
    """
    non_passthrough = len(self.schema.items) - len(self.passthrough_columns)
    return non_passthrough + OVERHEAD_VARIABLES
613
+
614
@property
def relation_ops_created(self) -> int:
    """Number of relational ops this node contributes to the planning-complexity estimate.

    Unpivot is essentially a cross join and a projection, so it counts as two.
    """
    ops_created = 2
    return ops_created
620
+
501
621
502
622
@dataclass (frozen = True )
503
623
class RandomSampleNode (UnaryNode ):
@@ -513,3 +633,7 @@ def row_preserving(self) -> bool:
513
633
514
634
def __hash__(self):
    """Return the hash value stored on the node as ``_node_hash``."""
    node_hash = self._node_hash
    return node_hash
636
+
637
@property
def variables_introduced(self) -> int:
    """Number of variables this node contributes to the planning-complexity estimate: always one."""
    introduced = 1
    return introduced
0 commit comments