Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 0 additions & 28 deletions bigframes/core/array_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,37 +401,10 @@ def aggregate(
)
)

def project_window_op(
self,
column_name: str,
op: agg_ops.UnaryWindowOp,
window_spec: WindowSpec,
*,
never_skip_nulls=False,
skip_reproject_unsafe: bool = False,
) -> Tuple[ArrayValue, str]:
"""
Creates a new expression based on this expression with unary operation applied to one column.
column_name: the id of the input column present in the expression
op: the windowable operator to apply to the input column
window_spec: a specification of the window over which to apply the operator
output_name: the id to assign to the output of the operator, by default will replace input col if distinct output id not provided
never_skip_nulls: will disable null skipping for operators that would otherwise do so
skip_reproject_unsafe: skips the reprojection step, can be used when performing many non-dependent window operations, user responsible for not nesting window expressions, or using outputs as join, filter or aggregation keys before a reprojection
"""

return self.project_window_expr(
agg_expressions.UnaryAggregation(op, ex.deref(column_name)),
window_spec,
never_skip_nulls,
skip_reproject_unsafe,
)

def project_window_expr(
self,
expression: agg_expressions.Aggregation,
window: WindowSpec,
never_skip_nulls=False,
skip_reproject_unsafe: bool = False,
):
output_name = self._gen_namespaced_uid()
Expand All @@ -442,7 +415,6 @@ def project_window_expr(
expression=expression,
window_spec=window,
output_name=ids.ColumnId(output_name),
never_skip_nulls=never_skip_nulls,
skip_reproject_unsafe=skip_reproject_unsafe,
)
),
Expand Down
6 changes: 0 additions & 6 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,6 @@ def multi_apply_window_op(
window_spec: windows.WindowSpec,
*,
skip_null_groups: bool = False,
never_skip_nulls: bool = False,
) -> typing.Tuple[Block, typing.Sequence[str]]:
block = self
result_ids = []
Expand All @@ -1103,7 +1102,6 @@ def multi_apply_window_op(
skip_reproject_unsafe=(i + 1) < len(columns),
result_label=label,
skip_null_groups=skip_null_groups,
never_skip_nulls=never_skip_nulls,
)
result_ids.append(result_id)
return block, result_ids
Expand Down Expand Up @@ -1184,15 +1182,13 @@ def apply_window_op(
result_label: Label = None,
skip_null_groups: bool = False,
skip_reproject_unsafe: bool = False,
never_skip_nulls: bool = False,
) -> typing.Tuple[Block, str]:
agg_expr = agg_expressions.UnaryAggregation(op, ex.deref(column))
return self.apply_analytic(
agg_expr,
window_spec,
result_label,
skip_reproject_unsafe=skip_reproject_unsafe,
never_skip_nulls=never_skip_nulls,
skip_null_groups=skip_null_groups,
)

Expand All @@ -1203,7 +1199,6 @@ def apply_analytic(
result_label: Label,
*,
skip_reproject_unsafe: bool = False,
never_skip_nulls: bool = False,
skip_null_groups: bool = False,
) -> typing.Tuple[Block, str]:
block = self
Expand All @@ -1214,7 +1209,6 @@ def apply_analytic(
agg_expr,
window,
skip_reproject_unsafe=skip_reproject_unsafe,
never_skip_nulls=never_skip_nulls,
)
block = Block(
expr,
Expand Down
7 changes: 0 additions & 7 deletions bigframes/core/compile/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,16 +394,13 @@ def project_window_op(
expression: ex_types.Aggregation,
window_spec: WindowSpec,
output_name: str,
*,
never_skip_nulls=False,
) -> UnorderedIR:
"""
Creates a new expression based on this expression with unary operation applied to one column.
column_name: the id of the input column present in the expression
op: the windowable operator to apply to the input column
window_spec: a specification of the window over which to apply the operator
output_name: the id to assign to the output of the operator
never_skip_nulls: will disable null skipping for operators that would otherwise do so
"""
# Cannot nest analytic expressions, so reproject to cte first if needed.
# Also ibis cannot window literals, so need to reproject those (even though this is legal in googlesql)
Expand All @@ -425,7 +422,6 @@ def project_window_op(
expression,
window_spec,
output_name,
never_skip_nulls=never_skip_nulls,
)

if expression.op.order_independent and window_spec.is_unbounded:
Expand All @@ -437,9 +433,6 @@ def project_window_op(
expression, window_spec
)
clauses: list[tuple[ex.Expression, ex.Expression]] = []
if expression.op.skips_nulls and not never_skip_nulls:
for input in expression.inputs:
clauses.append((ops.isnull_op.as_expr(input), ex.const(None)))
if window_spec.min_periods and len(expression.inputs) > 0:
if not expression.op.nulls_count_for_min_values:
is_observation = ops.notnull_op.as_expr()
Expand Down
1 change: 0 additions & 1 deletion bigframes/core/compile/ibis_compiler/ibis_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,6 @@ def compile_window(node: nodes.WindowOpNode, child: compiled.UnorderedIR):
node.expression,
node.window_spec,
node.output_name.sql,
never_skip_nulls=node.never_skip_nulls,
)
return result

Expand Down
21 changes: 0 additions & 21 deletions bigframes/core/compile/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import dataclasses
import functools
import itertools
import operator
from typing import cast, Literal, Optional, Sequence, Tuple, Type, TYPE_CHECKING

import pandas as pd
Expand Down Expand Up @@ -868,26 +867,6 @@ def compile_window(self, node: nodes.WindowOpNode):
df, node.expression, node.window_spec, node.output_name.sql
)
result = pl.concat([df, window_result], how="horizontal")

# Probably easier just to pull this out as a rewriter
if (
node.expression.op.skips_nulls
and not node.never_skip_nulls
and node.expression.column_references
):
nullity_expr = functools.reduce(
operator.or_,
(
pl.col(column.sql).is_null()
for column in node.expression.column_references
),
)
result = result.with_columns(
pl.when(nullity_expr)
.then(None)
.otherwise(pl.col(node.output_name.sql))
.alias(node.output_name.sql)
)
return result

def _calc_row_analytic_func(
Expand Down
6 changes: 1 addition & 5 deletions bigframes/core/compile/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,8 @@ def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotI
)

clauses: list[tuple[sge.Expression, sge.Expression]] = []
if node.expression.op.skips_nulls and not node.never_skip_nulls:
for column in inputs:
clauses.append((sge.Is(this=column, expression=sge.Null()), sge.Null()))

if window_spec.min_periods and len(inputs) > 0:
if node.expression.op.skips_nulls:
if not node.expression.op.nulls_count_for_min_values:
# Most operations do not count NULL values towards min_periods
not_null_columns = [
sge.Not(this=sge.Is(this=column, expression=sge.Null()))
Expand Down
19 changes: 16 additions & 3 deletions bigframes/core/groupby/dataframe_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import bigframes.core.window_spec as window_specs
import bigframes.dataframe as df
import bigframes.dtypes as dtypes
import bigframes.operations
import bigframes.operations.aggregations as agg_ops
import bigframes.series as series

Expand Down Expand Up @@ -747,14 +748,26 @@ def _apply_window_op(
window_spec = window or window_specs.cumulative_rows(
grouping_keys=tuple(self._by_col_ids)
)
columns, _ = self._aggregated_columns(numeric_only=numeric_only)
columns, labels = self._aggregated_columns(numeric_only=numeric_only)
block, result_ids = self._block.multi_apply_window_op(
columns,
op,
window_spec=window_spec,
)
result = df.DataFrame(block.select_columns(result_ids))
return result
block = block.project_exprs(
tuple(
bigframes.operations.where_op.as_expr(
r_col,
bigframes.operations.notnull_op.as_expr(og_col),
ex.const(None),
)
for og_col, r_col in zip(columns, result_ids)
),
labels=labels,
drop=True,
)

return df.DataFrame(block)

def _resolve_label(self, label: blocks.Label) -> str:
"""Resolve label to column id."""
Expand Down
14 changes: 11 additions & 3 deletions bigframes/core/groupby/series_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import bigframes.core.window_spec as window_specs
import bigframes.dataframe as df
import bigframes.dtypes
import bigframes.operations
import bigframes.operations.aggregations as agg_ops
import bigframes.series as series

Expand Down Expand Up @@ -339,7 +340,6 @@ def cumcount(self, *args, **kwargs) -> series.Series:
self._apply_window_op(
agg_ops.SizeUnaryOp(),
discard_name=True,
never_skip_nulls=True,
)
- 1
)
Expand Down Expand Up @@ -426,7 +426,6 @@ def _apply_window_op(
op: agg_ops.UnaryWindowOp,
discard_name=False,
window: typing.Optional[window_specs.WindowSpec] = None,
never_skip_nulls: bool = False,
) -> series.Series:
"""Apply window op to groupby. Defaults to grouped cumulative window."""
window_spec = window or window_specs.cumulative_rows(
Expand All @@ -439,6 +438,15 @@ def _apply_window_op(
op,
result_label=label,
window_spec=window_spec,
never_skip_nulls=never_skip_nulls,
)
if op.skips_nulls:
block, result_id = block.project_expr(
bigframes.operations.where_op.as_expr(
result_id,
bigframes.operations.notnull_op.as_expr(self._value_column),
ex.const(None),
),
label,
)

return series.Series(block.select_column(result_id))
1 change: 0 additions & 1 deletion bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,7 +1394,6 @@ class WindowOpNode(UnaryNode, AdditiveNode):
expression: agg_expressions.Aggregation
window_spec: window.WindowSpec
output_name: identifiers.ColumnId
never_skip_nulls: bool = False
skip_reproject_unsafe: bool = False

def _validate(self):
Expand Down
3 changes: 2 additions & 1 deletion bigframes/core/rewrite/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNod
_rewrite_aggregation(root.expression, root.schema),
root.window_spec,
root.output_name,
root.never_skip_nulls,
root.skip_reproject_unsafe,
)

Expand Down Expand Up @@ -112,6 +111,8 @@ def _rewrite_expressions(expr: ex.Expression, schema: schema.ArraySchema) -> _Ty


def _rewrite_scalar_constant_expr(expr: ex.ScalarConstantExpression) -> _TypedExpr:
if expr.value is None:
return _TypedExpr(ex.const(None, expr.dtype), expr.dtype)
if expr.dtype == dtypes.TIMEDELTA_DTYPE:
int_repr = utils.timedelta_to_micros(expr.value) # type: ignore
return _TypedExpr(ex.const(int_repr, expr.dtype), expr.dtype)
Expand Down
1 change: 0 additions & 1 deletion bigframes/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def _aggregate_block(self, op: agg_ops.UnaryAggregateOp) -> blocks.Block:
op,
self._window_spec,
skip_null_groups=self._drop_null_groups,
never_skip_nulls=True,
)

if self._window_spec.grouping_keys:
Expand Down
17 changes: 16 additions & 1 deletion bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4159,7 +4159,22 @@ def _apply_window_op(
op,
window_spec=window_spec,
)
return DataFrame(block.select_columns(result_ids))
if op.skips_nulls:
block = block.project_exprs(
tuple(
bigframes.operations.where_op.as_expr(
r_col,
bigframes.operations.notnull_op.as_expr(og_col),
ex.const(None),
)
for og_col, r_col in zip(self._block.value_columns, result_ids)
),
labels=self._block.column_labels,
drop=True,
)
else:
block = block.select_columns(result_ids)
return DataFrame(block)

@validations.requires_ordering()
def sample(
Expand Down
4 changes: 4 additions & 0 deletions bigframes/operations/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ def output_type(self, *input_types: dtypes.ExpressionType):
class SizeUnaryOp(UnaryAggregateOp):
name: ClassVar[str] = "size"

@property
def skips_nulls(self):
return False

def output_type(self, *input_types: dtypes.ExpressionType):
return dtypes.INT_DTYPE

Expand Down
6 changes: 5 additions & 1 deletion bigframes/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1780,7 +1780,11 @@ def _apply_window_op(
block, result_id = block.apply_window_op(
self._value_column, op, window_spec=window_spec, result_label=self.name
)
return Series(block.select_column(result_id))
result = Series(block.select_column(result_id))
if op.skips_nulls:
return result.where(self.notna(), None)
else:
return result

def value_counts(
self,
Expand Down
3 changes: 0 additions & 3 deletions tests/system/small/engines/test_windowing.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,10 @@ def test_engines_with_offsets(
assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("never_skip_nulls", [True, False])
@pytest.mark.parametrize("agg_op", [agg_ops.sum_op, agg_ops.count_op])
def test_engines_with_rows_window(
scalars_array_value: array_value.ArrayValue,
bigquery_client: bigquery.Client,
never_skip_nulls,
agg_op,
):
window = window_spec.WindowSpec(
Expand All @@ -61,7 +59,6 @@ def test_engines_with_rows_window(
),
window_spec=window,
output_name=identifiers.ColumnId("agg_int64"),
never_skip_nulls=never_skip_nulls,
skip_reproject_unsafe=False,
)

Expand Down
10 changes: 5 additions & 5 deletions tests/system/small/session/test_read_gbq_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def test_read_gbq_query_w_allow_large_results(session: bigframes.Session):
allow_large_results=False,
)
assert df_false.shape == (1, 1)
roots_false = df_false._get_block().expr.node.roots
assert any(isinstance(node, nodes.ReadLocalNode) for node in roots_false)
assert not any(isinstance(node, nodes.ReadTableNode) for node in roots_false)
nodes_false = df_false._get_block().expr.node.unique_nodes()
assert any(isinstance(node, nodes.ReadLocalNode) for node in nodes_false)
assert not any(isinstance(node, nodes.ReadTableNode) for node in nodes_false)

# Large results allowed should wrap a table.
df_true = session.read_gbq(
Expand All @@ -47,8 +47,8 @@ def test_read_gbq_query_w_allow_large_results(session: bigframes.Session):
allow_large_results=True,
)
assert df_true.shape == (1, 1)
roots_true = df_true._get_block().expr.node.roots
assert any(isinstance(node, nodes.ReadTableNode) for node in roots_true)
nodes_true = df_true._get_block().expr.node.unique_nodes()
assert any(isinstance(node, nodes.ReadTableNode) for node in nodes_true)


def test_read_gbq_query_w_columns(session: bigframes.Session):
Expand Down
Loading