diff --git a/bigframes/core/array_value.py b/bigframes/core/array_value.py
index b47637cb59..08b528a9e8 100644
--- a/bigframes/core/array_value.py
+++ b/bigframes/core/array_value.py
@@ -400,32 +400,6 @@ def aggregate(
             )
         )
 
-    def project_window_op(
-        self,
-        column_name: str,
-        op: agg_ops.UnaryWindowOp,
-        window_spec: WindowSpec,
-        *,
-        never_skip_nulls=False,
-        skip_reproject_unsafe: bool = False,
-    ) -> Tuple[ArrayValue, str]:
-        """
-        Creates a new expression based on this expression with unary operation applied to one column.
-        column_name: the id of the input column present in the expression
-        op: the windowable operator to apply to the input column
-        window_spec: a specification of the window over which to apply the operator
-        output_name: the id to assign to the output of the operator, by default will replace input col if distinct output id not provided
-        never_skip_nulls: will disable null skipping for operators that would otherwise do so
-        skip_reproject_unsafe: skips the reprojection step, can be used when performing many non-dependent window operations, user responsible for not nesting window expressions, or using outputs as join, filter or aggregation keys before a reprojection
-        """
-
-        return self.project_window_expr(
-            ex.UnaryAggregation(op, ex.deref(column_name)),
-            window_spec,
-            never_skip_nulls,
-            skip_reproject_unsafe,
-        )
-
     def project_window_expr(
         self,
         expression: ex.Aggregation,
diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py
index 93f072973c..be38dbcce6 100644
--- a/bigframes/core/compile/sqlglot/compiler.py
+++ b/bigframes/core/compile/sqlglot/compiler.py
@@ -219,6 +219,32 @@ def compile_filter(
         condition = scalar_compiler.compile_scalar_expression(node.predicate)
         return child.filter(condition)
 
+    @_compile_node.register
+    def compile_window(
+        self, node: nodes.WindowOpNode, child: ir.SQLGlotIR
+    ) -> ir.SQLGlotIR:
+        """Compiles a window-op node by delegating to ``child.window``."""
+        column_references: tuple[sge.Expression, ...] = tuple(
+            scalar_compiler.compile_scalar_expression(expression.DerefOp(column))
+            for column in node.expression.column_references
+        )
+
+        # TODO: detect whether the op can be windowed directly, i.e. none of the
+        # columns referenced by node.expression or node.window_spec is a window.
+        window_spec = node.window_spec
+        if node.expression.op.order_independent and window_spec.is_unbounded:
+            # notably percentile_cont does not support ordering clause
+            window_spec = window_spec.without_order()
+
+        return child.window(
+            column_references=column_references,
+            window_spec=window_spec,
+            # NOTE(review): node.output_name is compiled as a scalar expression
+            # here; confirm it should not be passed as an identifier instead.
+            output_name=scalar_compiler.compile_scalar_expression(node.output_name),
+            skip_nulls=node.expression.op.skips_nulls and not node.never_skip_nulls,
+        )
+
     @_compile_node.register
     def compile_join(
         self, node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py
index c0bed4090c..ddb0dcc82e 100644
--- a/bigframes/core/compile/sqlglot/sqlglot_ir.py
+++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py
@@ -25,11 +25,9 @@
 import sqlglot.expressions as sge
 
 from bigframes import dtypes
-from bigframes.core import guid, utils
+from bigframes.core import guid, local_data, schema, utils, window_spec
 from bigframes.core.compile.sqlglot.expressions import typed_expr
 import bigframes.core.compile.sqlglot.sqlglot_types as sgt
-import bigframes.core.local_data as local_data
-import bigframes.core.schema as bf_schema
 
 # shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0.
try: @@ -68,7 +66,7 @@ def sql(self) -> str: def from_pyarrow( cls, pa_table: pa.Table, - schema: bf_schema.ArraySchema, + schema: schema.ArraySchema, uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Builds SQLGlot expression from a pyarrow table. @@ -293,13 +291,28 @@ def filter( expr=new_expr.where(condition, append=False), uid_gen=self.uid_gen ) + def window( + self, + column_references: tuple[sge.Expression, ...], + window_spec: window_spec.WindowSpec, + output_name: sge.Expression, + skip_nulls: bool, + ) -> SQLGlotIR: + new_expr = _select_to_cte( + self.expr, + sge.to_identifier( + next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + ), + ) + + return self + def join( self, right: SQLGlotIR, join_type: typing.Literal["inner", "outer", "left", "right", "cross"], conditions: tuple[tuple[typed_expr.TypedExpr, typed_expr.TypedExpr], ...], - *, - joins_nulls: bool = True, + joins_nulls: bool, ) -> SQLGlotIR: """Joins the current query with another SQLGlotIR instance.""" left_cte_name = sge.to_identifier( diff --git a/tests/unit/core/compile/sqlglot/test_compile_window.py b/tests/unit/core/compile/sqlglot/test_compile_window.py new file mode 100644 index 0000000000..8eed2172e8 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/test_compile_window.py @@ -0,0 +1,26 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import pytest
+
+import bigframes.pandas as bpd
+
+pytest.importorskip("pytest_snapshot")
+
+
+def test_compile_window(scalar_types_df: bpd.DataFrame, snapshot):
+    """Snapshot the SQL for ``diff()`` on an index-sorted single-column frame."""
+    bf_df = scalar_types_df[["int64_col"]].sort_index()
+    result = bf_df.diff()
+    snapshot.assert_match(result.sql, "out.sql")