14
14
15
15
from __future__ import annotations
16
16
17
- from dataclasses import dataclass , field , fields
17
+ from dataclasses import dataclass , field , fields , replace
18
18
import functools
19
19
import itertools
20
20
import typing
21
- from typing import Tuple
21
+ from typing import Callable , Tuple
22
22
23
23
import pandas
24
24
@@ -100,6 +100,16 @@ def roots(self) -> typing.Set[BigFrameNode]:
100
100
)
101
101
return set (roots )
102
102
103
+ @functools .cached_property
104
+ def complexity (self ) -> int :
105
+ """A crude measure of the query complexity. Not necessarily predictive of the complexity of the actual computation."""
106
+ return 0
107
+
108
+ def transform_children (
109
+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
110
+ ) -> BigFrameNode :
111
+ return self
112
+
103
113
104
114
@dataclass (frozen = True )
105
115
class UnaryNode (BigFrameNode ):
@@ -109,6 +119,11 @@ class UnaryNode(BigFrameNode):
109
119
def child_nodes (self ) -> typing .Sequence [BigFrameNode ]:
110
120
return (self .child ,)
111
121
122
+ def transform_children (
123
+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
124
+ ) -> BigFrameNode :
125
+ return replace (self , child = t (self .child ))
126
+
112
127
113
128
@dataclass (frozen = True )
114
129
class JoinNode (BigFrameNode ):
@@ -138,6 +153,18 @@ def peekable(self) -> bool:
138
153
single_root = len (self .roots ) == 1
139
154
return children_peekable and single_root
140
155
156
+ @functools .cached_property
157
+ def complexity (self ) -> int :
158
+ child_complexity = sum (child .complexity for child in self .child_nodes )
159
+ return child_complexity * 2
160
+
161
+ def transform_children (
162
+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
163
+ ) -> BigFrameNode :
164
+ return replace (
165
+ self , left_child = t (self .left_child ), right_child = t (self .right_child )
166
+ )
167
+
141
168
142
169
@dataclass (frozen = True )
143
170
class ConcatNode (BigFrameNode ):
@@ -150,6 +177,16 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]:
150
177
def __hash__ (self ):
151
178
return self ._node_hash
152
179
180
+ @functools .cached_property
181
+ def complexity (self ) -> int :
182
+ child_complexity = sum (child .complexity for child in self .child_nodes )
183
+ return child_complexity * 2
184
+
185
+ def transform_children (
186
+ self , t : Callable [[BigFrameNode ], BigFrameNode ]
187
+ ) -> BigFrameNode :
188
+ return replace (self , children = tuple (t (child ) for child in self .children ))
189
+
153
190
154
191
# Input Nodex
155
192
@dataclass (frozen = True )
@@ -167,6 +204,11 @@ def peekable(self) -> bool:
167
204
def roots (self ) -> typing .Set [BigFrameNode ]:
168
205
return {self }
169
206
207
+ @functools .cached_property
208
+ def complexity (self ) -> int :
209
+ # TODO: Set to number of columns once this is more readily available
210
+ return 500
211
+
170
212
171
213
# TODO: Refactor to take raw gbq object reference
172
214
@dataclass (frozen = True )
@@ -192,6 +234,10 @@ def peekable(self) -> bool:
192
234
def roots (self ) -> typing .Set [BigFrameNode ]:
193
235
return {self }
194
236
237
+ @functools .cached_property
238
+ def complexity (self ) -> int :
239
+ return len (self .columns )
240
+
195
241
196
242
# Unary nodes
197
243
@dataclass (frozen = True )
@@ -209,6 +255,10 @@ def peekable(self) -> bool:
209
255
def non_local (self ) -> bool :
210
256
return False
211
257
258
+ @functools .cached_property
259
+ def complexity (self ) -> int :
260
+ return self .child .complexity * 2
261
+
212
262
213
263
@dataclass (frozen = True )
214
264
class FilterNode (UnaryNode ):
@@ -221,6 +271,10 @@ def row_preserving(self) -> bool:
221
271
def __hash__ (self ):
222
272
return self ._node_hash
223
273
274
+ @functools .cached_property
275
+ def complexity (self ) -> int :
276
+ return self .child .complexity
277
+
224
278
225
279
@dataclass (frozen = True )
226
280
class OrderByNode (UnaryNode ):
@@ -229,6 +283,10 @@ class OrderByNode(UnaryNode):
229
283
def __hash__ (self ):
230
284
return self ._node_hash
231
285
286
+ @functools .cached_property
287
+ def complexity (self ) -> int :
288
+ return self .child .complexity
289
+
232
290
233
291
@dataclass (frozen = True )
234
292
class ReversedNode (UnaryNode ):
@@ -238,6 +296,10 @@ class ReversedNode(UnaryNode):
238
296
def __hash__ (self ):
239
297
return self ._node_hash
240
298
299
+ @functools .cached_property
300
+ def complexity (self ) -> int :
301
+ return self .child .complexity
302
+
241
303
242
304
@dataclass (frozen = True )
243
305
class ProjectionNode (UnaryNode ):
@@ -246,6 +308,10 @@ class ProjectionNode(UnaryNode):
246
308
def __hash__ (self ):
247
309
return self ._node_hash
248
310
311
+ @functools .cached_property
312
+ def complexity (self ) -> int :
313
+ return self .child .complexity
314
+
249
315
250
316
# TODO: Merge RowCount into Aggregate Node?
251
317
# Row count can be compute from table metadata sometimes, so it is a bit special.
@@ -259,6 +325,10 @@ def row_preserving(self) -> bool:
259
325
def non_local (self ) -> bool :
260
326
return True
261
327
328
+ @functools .cached_property
329
+ def complexity (self ) -> int :
330
+ return self .child .complexity
331
+
262
332
263
333
@dataclass (frozen = True )
264
334
class AggregateNode (UnaryNode ):
@@ -281,6 +351,10 @@ def peekable(self) -> bool:
281
351
def non_local (self ) -> bool :
282
352
return True
283
353
354
+ @functools .cached_property
355
+ def complexity (self ) -> int :
356
+ return self .child .complexity * 2
357
+
284
358
285
359
@dataclass (frozen = True )
286
360
class WindowOpNode (UnaryNode ):
@@ -302,12 +376,22 @@ def peekable(self) -> bool:
302
376
def non_local (self ) -> bool :
303
377
return True
304
378
379
+ @functools .cached_property
380
+ def complexity (self ) -> int :
381
+ if self .skip_reproject_unsafe :
382
+ return self .child .complexity
383
+ return self .child .complexity * 2
384
+
305
385
306
386
@dataclass (frozen = True )
307
387
class ReprojectOpNode (UnaryNode ):
308
388
def __hash__ (self ):
309
389
return self ._node_hash
310
390
391
+ @functools .cached_property
392
+ def complexity (self ) -> int :
393
+ return self .child .complexity * 2
394
+
311
395
312
396
@dataclass (frozen = True )
313
397
class UnpivotNode (UnaryNode ):
@@ -337,6 +421,10 @@ def non_local(self) -> bool:
337
421
def peekable (self ) -> bool :
338
422
return False
339
423
424
+ @functools .cached_property
425
+ def complexity (self ) -> int :
426
+ return self .child .complexity * 2
427
+
340
428
341
429
@dataclass (frozen = True )
342
430
class RandomSampleNode (UnaryNode ):
@@ -350,5 +438,9 @@ def deterministic(self) -> bool:
350
438
def row_preserving (self ) -> bool :
351
439
return False
352
440
441
+ @functools .cached_property
442
+ def complexity (self ) -> int :
443
+ return self .child .complexity
444
+
353
445
def __hash__ (self ):
354
446
return self ._node_hash
0 commit comments