14
14
15
15
from __future__ import annotations
16
16
17
- from dataclasses import dataclass , field , fields
17
+ import abc
18
+ from dataclasses import dataclass , field , fields , replace
18
19
import functools
19
20
import itertools
20
21
import typing
21
- from typing import Tuple
22
+ from typing import Callable , Tuple
22
23
23
24
import pandas
24
25
@@ -100,6 +101,19 @@ def roots(self) -> typing.Set[BigFrameNode]:
100
101
)
101
102
return set (roots )
102
103
104
@property
@abc.abstractmethod
def complexity(self) -> int:
    """Return a crude measure of the query complexity.

    The value is not necessarily predictive of the complexity of the
    actual computation.
    """
    # ``property`` forwards ``__isabstractmethod__`` from the wrapped function,
    # whereas ``functools.cached_property`` does not, so the abstract marker
    # would otherwise be invisible to ``abc``. Concrete node types can still
    # override this attribute with a cached property.
    ...
109
+
110
@abc.abstractmethod
def transform_children(
    self, t: Callable[[BigFrameNode], BigFrameNode]
) -> BigFrameNode:
    """Apply the function ``t`` to each child node.

    Declared abstract here; this base definition has no implementation.
    """
116
+
103
117
104
118
@dataclass (frozen = True )
105
119
class UnaryNode (BigFrameNode ):
@@ -109,6 +123,15 @@ class UnaryNode(BigFrameNode):
109
123
def child_nodes (self ) -> typing .Sequence [BigFrameNode ]:
110
124
return (self .child ,)
111
125
126
def transform_children(
    self, t: Callable[[BigFrameNode], BigFrameNode]
) -> BigFrameNode:
    """Return a copy of this node with ``child`` replaced by ``t(child)``."""
    transformed_child = t(self.child)
    return replace(self, child=transformed_child)
130
+
131
@functools.cached_property
def complexity(self) -> int:
    """Return the child's complexity plus one."""
    child_cost = self.child.complexity
    return child_cost + 1
134
+
112
135
113
136
@dataclass (frozen = True )
114
137
class JoinNode (BigFrameNode ):
@@ -138,6 +161,18 @@ def peekable(self) -> bool:
138
161
single_root = len (self .roots ) == 1
139
162
return children_peekable and single_root
140
163
164
@functools.cached_property
def complexity(self) -> int:
    """Return twice the summed complexity of all child nodes."""
    total = 0
    for node in self.child_nodes:
        total += node.complexity
    return total * 2
168
+
169
def transform_children(
    self, t: Callable[[BigFrameNode], BigFrameNode]
) -> BigFrameNode:
    """Return a copy of this node with ``t`` applied to both children.

    ``t`` is called on ``left_child`` first, then on ``right_child``.
    """
    new_left = t(self.left_child)
    new_right = t(self.right_child)
    return replace(self, left_child=new_left, right_child=new_right)
175
+
141
176
142
177
@dataclass (frozen = True )
143
178
class ConcatNode (BigFrameNode ):
@@ -150,6 +185,16 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]:
150
185
def __hash__ (self ):
151
186
return self ._node_hash
152
187
188
@functools.cached_property
def complexity(self) -> int:
    """Return double the combined complexity of the child nodes."""
    child_costs = [node.complexity for node in self.child_nodes]
    return sum(child_costs) * 2
192
+
193
def transform_children(
    self, t: Callable[[BigFrameNode], BigFrameNode]
) -> BigFrameNode:
    """Return a copy of this node with ``t`` applied to every entry of ``children``.

    The transformed children are stored as a tuple, in their original order.
    """
    transformed = tuple(map(t, self.children))
    return replace(self, children=transformed)
197
+
153
198
154
199
# Input Nodes
155
200
@dataclass (frozen = True )
@@ -167,6 +212,16 @@ def peekable(self) -> bool:
167
212
def roots (self ) -> typing .Set [BigFrameNode ]:
168
213
return {self }
169
214
215
@functools.cached_property
def complexity(self) -> int:
    """Return a fixed placeholder complexity for this node."""
    # TODO: Set to number of columns once this is more readily available
    placeholder_cost = 500
    return placeholder_cost
219
+
220
def transform_children(
    self, t: Callable[[BigFrameNode], BigFrameNode]
) -> BigFrameNode:
    """Return this node unchanged; ``t`` is never called."""
    unchanged = self
    return unchanged
224
+
170
225
171
226
# TODO: Refactor to take raw gbq object reference
172
227
@dataclass (frozen = True )
@@ -192,6 +247,15 @@ def peekable(self) -> bool:
192
247
def roots (self ) -> typing .Set [BigFrameNode ]:
193
248
return {self }
194
249
250
@functools.cached_property
def complexity(self) -> int:
    """Return the number of entries in ``columns`` plus a constant of five."""
    column_count = len(self.columns)
    return column_count + 5
253
+
254
def transform_children(
    self, t: Callable[[BigFrameNode], BigFrameNode]
) -> BigFrameNode:
    """Return ``self`` as-is without invoking ``t``."""
    # Nothing is transformed here, so the same object is handed back.
    return self
258
+
195
259
196
260
# Unary nodes
197
261
@dataclass (frozen = True )
@@ -209,6 +273,10 @@ def peekable(self) -> bool:
209
273
def non_local (self ) -> bool :
210
274
return False
211
275
276
@functools.cached_property
def complexity(self) -> int:
    """Return twice the child's complexity."""
    child_cost = self.child.complexity
    return child_cost * 2
279
+
212
280
213
281
@dataclass (frozen = True )
214
282
class FilterNode (UnaryNode ):
@@ -221,6 +289,10 @@ def row_preserving(self) -> bool:
221
289
def __hash__ (self ):
222
290
return self ._node_hash
223
291
292
@functools.cached_property
def complexity(self) -> int:
    """Return one more than the complexity of ``child``."""
    base = self.child.complexity
    return base + 1
295
+
224
296
225
297
@dataclass (frozen = True )
226
298
class OrderByNode (UnaryNode ):
@@ -229,6 +301,10 @@ class OrderByNode(UnaryNode):
229
301
def __hash__ (self ):
230
302
return self ._node_hash
231
303
304
@functools.cached_property
def complexity(self) -> int:
    """Return the child's complexity incremented by one."""
    inherited = self.child.complexity
    return inherited + 1
307
+
232
308
233
309
@dataclass (frozen = True )
234
310
class ReversedNode (UnaryNode ):
@@ -238,6 +314,10 @@ class ReversedNode(UnaryNode):
238
314
def __hash__ (self ):
239
315
return self ._node_hash
240
316
317
@functools.cached_property
def complexity(self) -> int:
    """Return ``child.complexity`` plus one."""
    upstream = self.child.complexity
    return upstream + 1
320
+
241
321
242
322
@dataclass (frozen = True )
243
323
class ProjectionNode (UnaryNode ):
@@ -246,6 +326,10 @@ class ProjectionNode(UnaryNode):
246
326
def __hash__ (self ):
247
327
return self ._node_hash
248
328
329
@functools.cached_property
def complexity(self) -> int:
    """Return the child's complexity plus one for this node."""
    child_complexity = self.child.complexity
    return child_complexity + 1
332
+
249
333
250
334
# TODO: Merge RowCount into Aggregate Node?
251
335
# Row count can sometimes be computed from table metadata, so it is a bit special.
@@ -259,6 +343,10 @@ def row_preserving(self) -> bool:
259
343
def non_local (self ) -> bool :
260
344
return True
261
345
346
@functools.cached_property
def complexity(self) -> int:
    """Return the complexity of ``child`` plus one."""
    source_cost = self.child.complexity
    return source_cost + 1
349
+
262
350
263
351
@dataclass (frozen = True )
264
352
class AggregateNode (UnaryNode ):
@@ -281,6 +369,10 @@ def peekable(self) -> bool:
281
369
def non_local (self ) -> bool :
282
370
return True
283
371
372
@functools.cached_property
def complexity(self) -> int:
    """Return double the complexity of ``child``."""
    input_cost = self.child.complexity
    return input_cost * 2
375
+
284
376
285
377
@dataclass (frozen = True )
286
378
class WindowOpNode (UnaryNode ):
@@ -302,12 +394,22 @@ def peekable(self) -> bool:
302
394
def non_local (self ) -> bool :
303
395
return True
304
396
397
@functools.cached_property
def complexity(self) -> int:
    """Return the child's complexity, doubled unless ``skip_reproject_unsafe`` is truthy."""
    base = self.child.complexity
    return base if self.skip_reproject_unsafe else base * 2
402
+
305
403
306
404
@dataclass(frozen=True)
class ReprojectOpNode(UnaryNode):
    # Adds no fields of its own; defines only hashing and a complexity estimate.

    def __hash__(self):
        # Hash is taken from the ``_node_hash`` attribute.
        return self._node_hash

    @functools.cached_property
    def complexity(self) -> int:
        """Return twice the child's complexity."""
        child_cost = self.child.complexity
        return child_cost * 2
412
+
311
413
312
414
@dataclass (frozen = True )
313
415
class UnpivotNode (UnaryNode ):
@@ -337,6 +439,10 @@ def non_local(self) -> bool:
337
439
def peekable (self ) -> bool :
338
440
return False
339
441
442
@functools.cached_property
def complexity(self) -> int:
    """Return the child's complexity multiplied by two."""
    before = self.child.complexity
    return before * 2
445
+
340
446
341
447
@dataclass (frozen = True )
342
448
class RandomSampleNode (UnaryNode ):
@@ -350,5 +456,9 @@ def deterministic(self) -> bool:
350
456
def row_preserving (self ) -> bool :
351
457
return False
352
458
459
@functools.cached_property
def complexity(self) -> int:
    """Return one plus the complexity of ``child``."""
    sampled_from = self.child.complexity
    return sampled_from + 1
462
+
353
463
def __hash__ (self ):
354
464
return self ._node_hash
0 commit comments