diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index 93f072973c..ff2bb32e3c 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -242,6 +242,28 @@ def compile_join( joins_nulls=node.joins_nulls, ) + @_compile_node.register + def compile_isin_join( + self, node: nodes.InNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR + ) -> ir.SQLGlotIR: + conditions = tuple( + typed_expr.TypedExpr( + scalar_compiler.compile_scalar_expression(node.left_col), + left.output_type, + ), + typed_expr.TypedExpr( + scalar_compiler.compile_scalar_expression(node.right_col), + right.output_type, + ), + ) + + return left.isin_join( + right, + indicator_col=node.indicator_col.sql, + conditions=conditions, + joins_nulls=node.joins_nulls, + ) + @_compile_node.register def compile_concat( self, node: nodes.ConcatNode, *children: ir.SQLGlotIR diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index c0bed4090c..ffdbab8b79 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -332,6 +332,56 @@ def join( return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + def isin_join( + self, + right: SQLGlotIR, + indicator_col: str, + conditions: tuple[typed_expr.TypedExpr, typed_expr.TypedExpr], + *, + joins_nulls: bool = True, + ) -> SQLGlotIR: + """Joins the current query with another SQLGlotIR instance.""" + # TODO: Optimization similar to Ibis: + # if isinstance(values, ArrayValue): + # return ops.ArrayContains(values, self).to_expr() + # elif isinstance(values, Column): + # return ops.InSubquery(values.as_table(), needle=self).to_expr() + # else: + # return ops.InValues(self, values).to_expr() + + raise NotImplementedError + # left_cte_name = sge.to_identifier( + # next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + # ) + # right_cte_name = sge.to_identifier( + # next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + # ) + + # left_select = _select_to_cte(self.expr, left_cte_name) + # right_select = _select_to_cte(right.expr, right_cte_name) + + # left_ctes = left_select.args.pop("with", []) + # right_ctes = right_select.args.pop("with", []) + # merged_ctes = [*left_ctes, *right_ctes] + + + + # join_conditions = [ + # _join_condition(left, right, joins_nulls) for left, right in conditions + # ] + # join_on = sge.And(expressions=join_conditions) if join_conditions else None + + # join_type_str = join_type if join_type != "outer" else "full outer" + # new_expr = ( + # sge.Select() + # .select(sge.Star()) + # .from_(sge.Table(this=left_cte_name)) + # .join(sge.Table(this=right_cte_name), on=join_on, join_type=join_type_str) + # ) + # new_expr.set("with", sge.With(expressions=merged_ctes)) + + # return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + def explode( self, column_names: tuple[str, ...], diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_isin/test_compile_isin/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_isin/test_compile_isin/out.sql new file mode 100644 index 0000000000..b9cced3226 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_isin/test_compile_isin/out.sql @@ -0,0 +1,10 @@ +WITH `bfcte_0` AS ( + SELECT + * + FROM UNNEST(ARRAY>[STRUCT(314159.0, 0), STRUCT(2.0, 1), STRUCT(3.0, 2), STRUCT(CAST(NULL AS FLOAT64), 3)]) +) +SELECT + `bfcol_0` AS `0` +FROM `bfcte_0` +ORDER BY + `bfcol_1` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/test_compile_isin.py b/tests/unit/core/compile/sqlglot/test_compile_isin.py new file mode 100644 index 0000000000..6a022873d2 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/test_compile_isin.py @@ -0,0 +1,30 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +import pytest + +import bigframes +import bigframes.pandas as bpd + +pytest.importorskip("pytest_snapshot") + + +def test_compile_isin( + scalar_types_df: bpd.DataFrame, compiler_session: bigframes.Session, snapshot +): + data = [314159, 2.0, 3, pd.NA] + s = bpd.Series(data, session=compiler_session) + bf_isin = scalar_types_df["int64_col"].isin(s).to_frame() + snapshot.assert_match(bf_isin.sql, "out.sql")