15
15
from __future__ import annotations
16
16
17
17
import abc
18
- from dataclasses import dataclass , field , fields
18
+ from dataclasses import dataclass , field , fields , replace
19
19
import functools
20
20
import itertools
21
21
import typing
22
- from typing import Tuple
22
+ from typing import Callable , Tuple
23
23
24
24
import pandas
25
25
39
39
import bigframes .session
40
40
41
41
42
+ # A fixed number of variable to assume for overhead on some operations
43
+ OVERHEAD_VARIABLES = 5
44
+
45
+
42
46
@dataclass (frozen = True )
43
47
class BigFrameNode :
44
48
"""
@@ -102,6 +106,60 @@ def roots(self) -> typing.Set[BigFrameNode]:
102
106
def schema (self ) -> schemata .ArraySchema :
103
107
...
104
108
109
+ @property
110
+ @abc .abstractmethod
111
+ def variables_introduced (self ) -> int :
112
+ """
113
+ Defines number of values created by the current node. Helps represent the "width" of a query
114
+ """
115
+ ...
116
+
117
+ @property
118
+ def relation_ops_created (self ) -> int :
119
+ """
120
+ Defines the number of relational ops generated by the current node. Used to estimate query planning complexity.
121
+ """
122
+ return 1
123
+
124
+ @property
125
+ def joins (self ) -> bool :
126
+ """
127
+ Defines whether the node joins data.
128
+ """
129
+ return False
130
+
131
+ @functools .cached_property
132
+ def total_variables (self ) -> int :
133
+ return self .variables_introduced + sum (
134
+ map (lambda x : x .total_variables , self .child_nodes )
135
+ )
136
+
137
+ @functools .cached_property
138
+ def total_relational_ops (self ) -> int :
139
+ return self .relation_ops_created + sum (
140
+ map (lambda x : x .total_relational_ops , self .child_nodes )
141
+ )
142
+
143
+ @functools .cached_property
144
+ def total_joins (self ) -> int :
145
+ return int (self .joins ) + sum (map (lambda x : x .total_joins , self .child_nodes ))
146
+
147
+ @property
148
+ def planning_complexity (self ) -> int :
149
+ """
150
+ Empirical heuristic measure of planning complexity.
151
+
152
+ Used to determine when to decompose overly complex computations. May require tuning.
153
+ """
154
+ return self .total_variables * self .total_relational_ops * (1 + self .total_joins )
155
+
156
+ @abc .abstractmethod
157
+ def transform_children (
158
+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
159
+ ) -> BigFrameNode :
160
+ """Apply a function to each child node."""
161
+ ...
162
+
105
163
106
164
@dataclass (frozen = True )
107
165
class UnaryNode (BigFrameNode ):
@@ -115,6 +173,11 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]:
115
173
def schema (self ) -> schemata .ArraySchema :
116
174
return self .child .schema
117
175
176
+ def transform_children (
177
+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
178
+ ) -> BigFrameNode :
179
+ return replace (self , child = t (self .child ))
180
+
118
181
119
182
@dataclass (frozen = True )
120
183
class JoinNode (BigFrameNode ):
@@ -154,6 +217,22 @@ def join_mapping_to_schema_item(mapping: JoinColumnMapping):
154
217
)
155
218
return schemata .ArraySchema (items )
156
219
220
+ @functools .cached_property
221
+ def variables_introduced (self ) -> int :
222
+ """Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
223
+ return OVERHEAD_VARIABLES
224
+
225
+ @property
226
+ def joins (self ) -> bool :
227
+ return True
228
+
229
+ def transform_children (
230
+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
231
+ ) -> BigFrameNode :
232
+ return replace (
233
+ self , left_child = t (self .left_child ), right_child = t (self .right_child )
234
+ )
235
+
157
236
158
237
@dataclass (frozen = True )
159
238
class ConcatNode (BigFrameNode ):
@@ -182,6 +261,16 @@ def schema(self) -> schemata.ArraySchema:
182
261
)
183
262
return schemata .ArraySchema (items )
184
263
264
+ @functools .cached_property
265
+ def variables_introduced (self ) -> int :
266
+ """Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
267
+ return len (self .schema .items ) + OVERHEAD_VARIABLES
268
+
269
+ def transform_children (
270
+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
271
+ ) -> BigFrameNode :
272
+ return replace (self , children = tuple (t (child ) for child in self .children ))
273
+
185
274
186
275
# Input Nodex
187
276
@dataclass (frozen = True )
@@ -201,6 +290,16 @@ def roots(self) -> typing.Set[BigFrameNode]:
201
290
def schema (self ) -> schemata .ArraySchema :
202
291
return self .data_schema
203
292
293
+ @functools .cached_property
294
+ def variables_introduced (self ) -> int :
295
+ """Defines the number of variables generated by the current node. Used to estimate query planning complexity."""
296
+ return len (self .schema .items ) + 1
297
+
298
+ def transform_children (
299
+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
300
+ ) -> BigFrameNode :
301
+ return self
302
+
204
303
205
304
# TODO: Refactor to take raw gbq object reference
206
305
@dataclass (frozen = True )
@@ -233,6 +332,20 @@ def schema(self) -> schemata.ArraySchema:
233
332
)
234
333
return schemata .ArraySchema (items )
235
334
335
+ @functools .cached_property
336
+ def variables_introduced (self ) -> int :
337
+ return len (self .columns ) + len (self .hidden_ordering_columns )
338
+
339
+ @property
340
+ def relation_ops_created (self ) -> int :
341
+ # Assume worst case, where readgbq actually has baked in analytic operation to generate index
342
+ return 2
343
+
344
+ def transform_children (
345
+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
346
+ ) -> BigFrameNode :
347
+ return self
348
+
236
349
237
350
# Unary nodes
238
351
@dataclass (frozen = True )
@@ -252,6 +365,14 @@ def schema(self) -> schemata.ArraySchema:
252
365
schemata .SchemaItem (self .col_id , bigframes .dtypes .INT_DTYPE )
253
366
)
254
367
368
+ @property
369
+ def relation_ops_created (self ) -> int :
370
+ return 2
371
+
372
+ @functools .cached_property
373
+ def variables_introduced (self ) -> int :
374
+ return 1
375
+
255
376
256
377
@dataclass (frozen = True )
257
378
class FilterNode (UnaryNode ):
@@ -264,6 +385,10 @@ def row_preserving(self) -> bool:
264
385
def __hash__ (self ):
265
386
return self ._node_hash
266
387
388
+ @property
389
+ def variables_introduced (self ) -> int :
390
+ return 1
391
+
267
392
268
393
@dataclass (frozen = True )
269
394
class OrderByNode (UnaryNode ):
@@ -281,6 +406,15 @@ def __post_init__(self):
281
406
def __hash__ (self ):
282
407
return self ._node_hash
283
408
409
+ @property
410
+ def variables_introduced (self ) -> int :
411
+ return 0
412
+
413
+ @property
414
+ def relation_ops_created (self ) -> int :
415
+ # Doesnt directly create any relational operations
416
+ return 0
417
+
284
418
285
419
@dataclass (frozen = True )
286
420
class ReversedNode (UnaryNode ):
@@ -290,6 +424,15 @@ class ReversedNode(UnaryNode):
290
424
def __hash__ (self ):
291
425
return self ._node_hash
292
426
427
+ @property
428
+ def variables_introduced (self ) -> int :
429
+ return 0
430
+
431
+ @property
432
+ def relation_ops_created (self ) -> int :
433
+ # Doesnt directly create any relational operations
434
+ return 0
435
+
293
436
294
437
@dataclass (frozen = True )
295
438
class ProjectionNode (UnaryNode ):
@@ -315,6 +458,12 @@ def schema(self) -> schemata.ArraySchema:
315
458
)
316
459
return schemata .ArraySchema (items )
317
460
461
+ @property
462
+ def variables_introduced (self ) -> int :
463
+ # ignore passthrough expressions
464
+ new_vars = sum (1 for i in self .assignments if not i [0 ].is_identity )
465
+ return new_vars
466
+
318
467
319
468
# TODO: Merge RowCount into Aggregate Node?
320
469
# Row count can be compute from table metadata sometimes, so it is a bit special.
@@ -334,6 +483,10 @@ def schema(self) -> schemata.ArraySchema:
334
483
(schemata .SchemaItem ("count" , bigframes .dtypes .INT_DTYPE ),)
335
484
)
336
485
486
+ @property
487
+ def variables_introduced (self ) -> int :
488
+ return 1
489
+
337
490
338
491
@dataclass (frozen = True )
339
492
class AggregateNode (UnaryNode ):
@@ -367,6 +520,10 @@ def schema(self) -> schemata.ArraySchema:
367
520
)
368
521
return schemata .ArraySchema (tuple ([* by_items , * agg_items ]))
369
522
523
+ @property
524
+ def variables_introduced (self ) -> int :
525
+ return len (self .aggregations ) + len (self .by_column_ids )
526
+
370
527
371
528
@dataclass (frozen = True )
372
529
class WindowOpNode (UnaryNode ):
@@ -396,12 +553,31 @@ def schema(self) -> schemata.ArraySchema:
396
553
schemata .SchemaItem (self .output_name , new_item_dtype )
397
554
)
398
555
556
+ @property
557
+ def variables_introduced (self ) -> int :
558
+ return 1
559
+
560
+ @property
561
+ def relation_ops_created (self ) -> int :
562
+ # Assume that if not reprojecting, that there is a sequence of window operations sharing the same window
563
+ return 0 if self .skip_reproject_unsafe else 4
564
+
399
565
566
+ # TODO: Remove this op
400
567
@dataclass (frozen = True )
401
568
class ReprojectOpNode (UnaryNode ):
402
569
def __hash__ (self ):
403
570
return self ._node_hash
404
571
572
+ @property
573
+ def variables_introduced (self ) -> int :
574
+ return 0
575
+
576
+ @property
577
+ def relation_ops_created (self ) -> int :
578
+ # This op is not a real transformation, just a hint to the sql generator
579
+ return 0
580
+
405
581
406
582
@dataclass (frozen = True )
407
583
class UnpivotNode (UnaryNode ):
@@ -428,6 +604,10 @@ def row_preserving(self) -> bool:
428
604
def non_local (self ) -> bool :
429
605
return True
430
606
607
+ @property
608
+ def joins (self ) -> bool :
609
+ return True
610
+
431
611
@functools .cached_property
432
612
def schema (self ) -> schemata .ArraySchema :
433
613
def infer_dtype (
@@ -469,6 +649,17 @@ def infer_dtype(
469
649
]
470
650
return schemata .ArraySchema ((* index_items , * value_items , * passthrough_items ))
471
651
652
+ @property
653
+ def variables_introduced (self ) -> int :
654
+ return (
655
+ len (self .schema .items ) - len (self .passthrough_columns ) + OVERHEAD_VARIABLES
656
+ )
657
+
658
+ @property
659
+ def relation_ops_created (self ) -> int :
660
+ # Unpivot is essentially a cross join and a projection.
661
+ return 2
662
+
472
663
473
664
@dataclass (frozen = True )
474
665
class RandomSampleNode (UnaryNode ):
@@ -485,6 +676,10 @@ def row_preserving(self) -> bool:
485
676
def __hash__ (self ):
486
677
return self ._node_hash
487
678
679
+ @property
680
+ def variables_introduced (self ) -> int :
681
+ return 1
682
+
488
683
489
684
@dataclass (frozen = True )
490
685
class ExplodeNode (UnaryNode ):
@@ -511,3 +706,11 @@ def schema(self) -> schemata.ArraySchema:
511
706
for name in self .child .schema .names
512
707
)
513
708
return schemata .ArraySchema (items )
709
+
710
+ @property
711
+ def relation_ops_created (self ) -> int :
712
+ return 3
713
+
714
+ @functools .cached_property
715
+ def variables_introduced (self ) -> int :
716
+ return len (self .column_ids ) + 1
0 commit comments