Skip to content

Commit 6e73d77

Browse files
refactor: Remove never_skip_nulls param from window node def (#2273)
1 parent 3e3fe25 commit 6e73d77

File tree

36 files changed

+88
-157
lines changed

36 files changed

+88
-157
lines changed

bigframes/core/array_value.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -401,37 +401,10 @@ def aggregate(
401401
)
402402
)
403403

404-
def project_window_op(
405-
self,
406-
column_name: str,
407-
op: agg_ops.UnaryWindowOp,
408-
window_spec: WindowSpec,
409-
*,
410-
never_skip_nulls=False,
411-
skip_reproject_unsafe: bool = False,
412-
) -> Tuple[ArrayValue, str]:
413-
"""
414-
Creates a new expression based on this expression with unary operation applied to one column.
415-
column_name: the id of the input column present in the expression
416-
op: the windowable operator to apply to the input column
417-
window_spec: a specification of the window over which to apply the operator
418-
output_name: the id to assign to the output of the operator, by default will replace input col if distinct output id not provided
419-
never_skip_nulls: will disable null skipping for operators that would otherwise do so
420-
skip_reproject_unsafe: skips the reprojection step, can be used when performing many non-dependent window operations, user responsible for not nesting window expressions, or using outputs as join, filter or aggregation keys before a reprojection
421-
"""
422-
423-
return self.project_window_expr(
424-
agg_expressions.UnaryAggregation(op, ex.deref(column_name)),
425-
window_spec,
426-
never_skip_nulls,
427-
skip_reproject_unsafe,
428-
)
429-
430404
def project_window_expr(
431405
self,
432406
expression: agg_expressions.Aggregation,
433407
window: WindowSpec,
434-
never_skip_nulls=False,
435408
skip_reproject_unsafe: bool = False,
436409
):
437410
output_name = self._gen_namespaced_uid()
@@ -442,7 +415,6 @@ def project_window_expr(
442415
expression=expression,
443416
window_spec=window,
444417
output_name=ids.ColumnId(output_name),
445-
never_skip_nulls=never_skip_nulls,
446418
skip_reproject_unsafe=skip_reproject_unsafe,
447419
)
448420
),

bigframes/core/blocks.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,6 @@ def multi_apply_window_op(
10901090
window_spec: windows.WindowSpec,
10911091
*,
10921092
skip_null_groups: bool = False,
1093-
never_skip_nulls: bool = False,
10941093
) -> typing.Tuple[Block, typing.Sequence[str]]:
10951094
block = self
10961095
result_ids = []
@@ -1103,7 +1102,6 @@ def multi_apply_window_op(
11031102
skip_reproject_unsafe=(i + 1) < len(columns),
11041103
result_label=label,
11051104
skip_null_groups=skip_null_groups,
1106-
never_skip_nulls=never_skip_nulls,
11071105
)
11081106
result_ids.append(result_id)
11091107
return block, result_ids
@@ -1184,15 +1182,13 @@ def apply_window_op(
11841182
result_label: Label = None,
11851183
skip_null_groups: bool = False,
11861184
skip_reproject_unsafe: bool = False,
1187-
never_skip_nulls: bool = False,
11881185
) -> typing.Tuple[Block, str]:
11891186
agg_expr = agg_expressions.UnaryAggregation(op, ex.deref(column))
11901187
return self.apply_analytic(
11911188
agg_expr,
11921189
window_spec,
11931190
result_label,
11941191
skip_reproject_unsafe=skip_reproject_unsafe,
1195-
never_skip_nulls=never_skip_nulls,
11961192
skip_null_groups=skip_null_groups,
11971193
)
11981194

@@ -1203,7 +1199,6 @@ def apply_analytic(
12031199
result_label: Label,
12041200
*,
12051201
skip_reproject_unsafe: bool = False,
1206-
never_skip_nulls: bool = False,
12071202
skip_null_groups: bool = False,
12081203
) -> typing.Tuple[Block, str]:
12091204
block = self
@@ -1214,7 +1209,6 @@ def apply_analytic(
12141209
agg_expr,
12151210
window,
12161211
skip_reproject_unsafe=skip_reproject_unsafe,
1217-
never_skip_nulls=never_skip_nulls,
12181212
)
12191213
block = Block(
12201214
expr,

bigframes/core/compile/compiled.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -394,16 +394,13 @@ def project_window_op(
394394
expression: ex_types.Aggregation,
395395
window_spec: WindowSpec,
396396
output_name: str,
397-
*,
398-
never_skip_nulls=False,
399397
) -> UnorderedIR:
400398
"""
401399
Creates a new expression based on this expression with unary operation applied to one column.
402400
column_name: the id of the input column present in the expression
403401
op: the windowable operator to apply to the input column
404402
window_spec: a specification of the window over which to apply the operator
405403
output_name: the id to assign to the output of the operator
406-
never_skip_nulls: will disable null skipping for operators that would otherwise do so
407404
"""
408405
# Cannot nest analytic expressions, so reproject to cte first if needed.
409406
# Also ibis cannot window literals, so need to reproject those (even though this is legal in googlesql)
@@ -425,7 +422,6 @@ def project_window_op(
425422
expression,
426423
window_spec,
427424
output_name,
428-
never_skip_nulls=never_skip_nulls,
429425
)
430426

431427
if expression.op.order_independent and window_spec.is_unbounded:
@@ -437,9 +433,6 @@ def project_window_op(
437433
expression, window_spec
438434
)
439435
clauses: list[tuple[ex.Expression, ex.Expression]] = []
440-
if expression.op.skips_nulls and not never_skip_nulls:
441-
for input in expression.inputs:
442-
clauses.append((ops.isnull_op.as_expr(input), ex.const(None)))
443436
if window_spec.min_periods and len(expression.inputs) > 0:
444437
if not expression.op.nulls_count_for_min_values:
445438
is_observation = ops.notnull_op.as_expr()

bigframes/core/compile/ibis_compiler/ibis_compiler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,6 @@ def compile_window(node: nodes.WindowOpNode, child: compiled.UnorderedIR):
269269
node.expression,
270270
node.window_spec,
271271
node.output_name.sql,
272-
never_skip_nulls=node.never_skip_nulls,
273272
)
274273
return result
275274

bigframes/core/compile/polars/compiler.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import dataclasses
1717
import functools
1818
import itertools
19-
import operator
2019
from typing import cast, Literal, Optional, Sequence, Tuple, Type, TYPE_CHECKING
2120

2221
import pandas as pd
@@ -868,26 +867,6 @@ def compile_window(self, node: nodes.WindowOpNode):
868867
df, node.expression, node.window_spec, node.output_name.sql
869868
)
870869
result = pl.concat([df, window_result], how="horizontal")
871-
872-
# Probably easier just to pull this out as a rewriter
873-
if (
874-
node.expression.op.skips_nulls
875-
and not node.never_skip_nulls
876-
and node.expression.column_references
877-
):
878-
nullity_expr = functools.reduce(
879-
operator.or_,
880-
(
881-
pl.col(column.sql).is_null()
882-
for column in node.expression.column_references
883-
),
884-
)
885-
result = result.with_columns(
886-
pl.when(nullity_expr)
887-
.then(None)
888-
.otherwise(pl.col(node.output_name.sql))
889-
.alias(node.output_name.sql)
890-
)
891870
return result
892871

893872
def _calc_row_analytic_func(

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -324,12 +324,8 @@ def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotI
324324
)
325325

326326
clauses: list[tuple[sge.Expression, sge.Expression]] = []
327-
if node.expression.op.skips_nulls and not node.never_skip_nulls:
328-
for column in inputs:
329-
clauses.append((sge.Is(this=column, expression=sge.Null()), sge.Null()))
330-
331327
if window_spec.min_periods and len(inputs) > 0:
332-
if node.expression.op.skips_nulls:
328+
if not node.expression.op.nulls_count_for_min_values:
333329
# Most operations do not count NULL values towards min_periods
334330
not_null_columns = [
335331
sge.Not(this=sge.Is(this=column, expression=sge.Null()))

bigframes/core/groupby/dataframe_group_by.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import bigframes.core.window_spec as window_specs
3939
import bigframes.dataframe as df
4040
import bigframes.dtypes as dtypes
41+
import bigframes.operations
4142
import bigframes.operations.aggregations as agg_ops
4243
import bigframes.series as series
4344

@@ -747,14 +748,26 @@ def _apply_window_op(
747748
window_spec = window or window_specs.cumulative_rows(
748749
grouping_keys=tuple(self._by_col_ids)
749750
)
750-
columns, _ = self._aggregated_columns(numeric_only=numeric_only)
751+
columns, labels = self._aggregated_columns(numeric_only=numeric_only)
751752
block, result_ids = self._block.multi_apply_window_op(
752753
columns,
753754
op,
754755
window_spec=window_spec,
755756
)
756-
result = df.DataFrame(block.select_columns(result_ids))
757-
return result
757+
block = block.project_exprs(
758+
tuple(
759+
bigframes.operations.where_op.as_expr(
760+
r_col,
761+
bigframes.operations.notnull_op.as_expr(og_col),
762+
ex.const(None),
763+
)
764+
for og_col, r_col in zip(columns, result_ids)
765+
),
766+
labels=labels,
767+
drop=True,
768+
)
769+
770+
return df.DataFrame(block)
758771

759772
def _resolve_label(self, label: blocks.Label) -> str:
760773
"""Resolve label to column id."""

bigframes/core/groupby/series_group_by.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import bigframes.core.window_spec as window_specs
3838
import bigframes.dataframe as df
3939
import bigframes.dtypes
40+
import bigframes.operations
4041
import bigframes.operations.aggregations as agg_ops
4142
import bigframes.series as series
4243

@@ -339,7 +340,6 @@ def cumcount(self, *args, **kwargs) -> series.Series:
339340
self._apply_window_op(
340341
agg_ops.SizeUnaryOp(),
341342
discard_name=True,
342-
never_skip_nulls=True,
343343
)
344344
- 1
345345
)
@@ -426,7 +426,6 @@ def _apply_window_op(
426426
op: agg_ops.UnaryWindowOp,
427427
discard_name=False,
428428
window: typing.Optional[window_specs.WindowSpec] = None,
429-
never_skip_nulls: bool = False,
430429
) -> series.Series:
431430
"""Apply window op to groupby. Defaults to grouped cumulative window."""
432431
window_spec = window or window_specs.cumulative_rows(
@@ -439,6 +438,15 @@ def _apply_window_op(
439438
op,
440439
result_label=label,
441440
window_spec=window_spec,
442-
never_skip_nulls=never_skip_nulls,
443441
)
442+
if op.skips_nulls:
443+
block, result_id = block.project_expr(
444+
bigframes.operations.where_op.as_expr(
445+
result_id,
446+
bigframes.operations.notnull_op.as_expr(self._value_column),
447+
ex.const(None),
448+
),
449+
label,
450+
)
451+
444452
return series.Series(block.select_column(result_id))

bigframes/core/nodes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1394,7 +1394,6 @@ class WindowOpNode(UnaryNode, AdditiveNode):
13941394
expression: agg_expressions.Aggregation
13951395
window_spec: window.WindowSpec
13961396
output_name: identifiers.ColumnId
1397-
never_skip_nulls: bool = False
13981397
skip_reproject_unsafe: bool = False
13991398

14001399
def _validate(self):

bigframes/core/rewrite/timedeltas.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNod
6767
_rewrite_aggregation(root.expression, root.schema),
6868
root.window_spec,
6969
root.output_name,
70-
root.never_skip_nulls,
7170
root.skip_reproject_unsafe,
7271
)
7372

@@ -112,6 +111,8 @@ def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _Ty
112111

113112

114113
def _rewrite_scalar_constant_expr(expr: ex.ScalarConstantExpression) -> _TypedExpr:
114+
if expr.value is None:
115+
return _TypedExpr(ex.const(None, expr.dtype), expr.dtype)
115116
if expr.dtype == dtypes.TIMEDELTA_DTYPE:
116117
int_repr = utils.timedelta_to_micros(expr.value) # type: ignore
117118
return _TypedExpr(ex.const(int_repr, expr.dtype), expr.dtype)

0 commit comments

Comments
 (0)