Skip to content

Commit 92ad3d4

Browse files
committed
fix window errors
1 parent 55fca3d commit 92ad3d4

File tree

6 files changed

+28
-54
lines changed

6 files changed

+28
-54
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 23 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def _(
121121
case_expr = _cut_ops_w_int_bins(op, column, op.bins, window)
122122
else: # Interpret as intervals
123123
case_expr = _cut_ops_w_intervals(op, column, op.bins, window)
124-
return apply_window_if_present(case_expr, window)
124+
return case_expr
125125

126126

127127
def _cut_ops_w_int_bins(
@@ -165,31 +165,18 @@ def _cut_ops_w_int_bins(
165165
col_min + sge.convert(this_bin + 1) * bin_width + right_adj
166166
)
167167
if op.right:
168-
value = sge.Struct(
169-
expressions=[
170-
sge.PropertyEQ(
171-
this=sge.Identifier(this="left_exclusive", quoted=True),
172-
expression=left,
173-
),
174-
sge.PropertyEQ(
175-
this=sge.Identifier(this="right_inclusive", quoted=True),
176-
expression=right,
177-
),
178-
]
179-
)
168+
left_identifier = sge.Identifier(this="left_exclusive", quoted=True)
169+
right_identifier = sge.Identifier(this="right_inclusive", quoted=True)
180170
else:
181-
value = sge.Struct(
182-
expressions=[
183-
sge.PropertyEQ(
184-
this=sge.Identifier(this="left_inclusive", quoted=True),
185-
expression=left,
186-
),
187-
sge.PropertyEQ(
188-
this=sge.Identifier(this="right_exclusive", quoted=True),
189-
expression=right,
190-
),
191-
]
192-
)
171+
left_identifier = sge.Identifier(this="left_inclusive", quoted=True)
172+
right_identifier = sge.Identifier(this="right_exclusive", quoted=True)
173+
174+
value = sge.Struct(
175+
expressions=[
176+
sge.PropertyEQ(this=left_identifier, expression=left),
177+
sge.PropertyEQ(this=right_identifier, expression=right),
178+
]
179+
)
193180

194181
condition: sge.Expression
195182
if this_bin == bins - 1:
@@ -242,31 +229,18 @@ def _cut_ops_w_intervals(
242229
value = ir._literal(list(op.labels)[this_bin], dtypes.STRING_DTYPE)
243230
else:
244231
if op.right:
245-
value = sge.Struct(
246-
expressions=[
247-
sge.PropertyEQ(
248-
this=sge.Identifier(this="left_exclusive", quoted=True),
249-
expression=left,
250-
),
251-
sge.PropertyEQ(
252-
this=sge.Identifier(this="right_inclusive", quoted=True),
253-
expression=right,
254-
),
255-
]
256-
)
232+
left_identifier = sge.Identifier(this="left_exclusive", quoted=True)
233+
right_identifier = sge.Identifier(this="right_inclusive", quoted=True)
257234
else:
258-
value = sge.Struct(
259-
expressions=[
260-
sge.PropertyEQ(
261-
this=sge.Identifier(this="left_inclusive", quoted=True),
262-
expression=left,
263-
),
264-
sge.PropertyEQ(
265-
this=sge.Identifier(this="right_exclusive", quoted=True),
266-
expression=right,
267-
),
268-
]
269-
)
235+
left_identifier = sge.Identifier(this="left_inclusive", quoted=True)
236+
right_identifier = sge.Identifier(this="right_exclusive", quoted=True)
237+
238+
value = sge.Struct(
239+
expressions=[
240+
sge.PropertyEQ(this=left_identifier, expression=left),
241+
sge.PropertyEQ(this=right_identifier, expression=right),
242+
]
243+
)
270244
case_expr = case_expr.when(condition, value)
271245
return case_expr
272246

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ WITH `bfcte_0` AS (
4747
3 * IEEE_DIVIDE(MAX(`int64_col`) OVER () - MIN(`int64_col`) OVER (), 3)
4848
) + 0 AS `right_inclusive`
4949
)
50-
END OVER () AS `bfcol_1`
50+
END AS `bfcol_1`
5151
FROM `bfcte_0`
5252
)
5353
SELECT

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/int_bins_labels.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ WITH `bfcte_0` AS (
1616
THEN 'b'
1717
WHEN `int64_col` IS NOT NULL
1818
THEN 'c'
19-
END OVER () AS `bfcol_1`
19+
END AS `bfcol_1`
2020
FROM `bfcte_0`
2121
)
2222
SELECT

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ WITH `bfcte_0` AS (
1010
THEN STRUCT(0 AS `left_exclusive`, 1 AS `right_inclusive`)
1111
WHEN `int64_col` > 1 AND `int64_col` <= 2
1212
THEN STRUCT(1 AS `left_exclusive`, 2 AS `right_inclusive`)
13-
END OVER () AS `bfcol_1`
13+
END AS `bfcol_1`
1414
FROM `bfcte_0`
1515
)
1616
SELECT

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_cut/interval_bins_labels.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ WITH `bfcte_0` AS (
1010
THEN 0
1111
WHEN `int64_col` > 1 AND `int64_col` <= 2
1212
THEN 1
13-
END OVER () AS `bfcol_1`
13+
END AS `bfcol_1`
1414
FROM `bfcte_0`
1515
)
1616
SELECT

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def test_cut(scalar_types_df: bpd.DataFrame, snapshot):
199199
# Loop through the aggregation map items
200200
for test_name, agg_expr in agg_ops_map.items():
201201
sql = _apply_unary_window_op(bf_df, agg_expr, window, test_name)
202-
202+
203203
snapshot.assert_match(sql, f"{test_name}.sql")
204204

205205

0 commit comments

Comments
 (0)