Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions bigframes/core/array_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,13 +450,15 @@ def project_window_expr(
)

def isin(
self, other: ArrayValue, lcol: str, rcol: str
self,
other: ArrayValue,
lcol: str,
) -> typing.Tuple[ArrayValue, str]:
assert len(other.column_ids) == 1
node = nodes.InNode(
self.node,
other.node,
ex.deref(lcol),
ex.deref(rcol),
indicator_col=ids.ColumnId.unique(),
)
return ArrayValue(node), node.indicator_col.name
Expand Down
2 changes: 1 addition & 1 deletion bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2326,7 +2326,7 @@ def isin(self, other: Block):
return block

def _isin_inner(self: Block, col: str, unique_values: core.ArrayValue) -> Block:
expr, matches = self._expr.isin(unique_values, col, unique_values.column_ids[0])
expr, matches = self._expr.isin(unique_values, col)

new_value_cols = tuple(
val_col if val_col != col else matches for val_col in self.value_columns
Expand Down
2 changes: 1 addition & 1 deletion bigframes/core/compile/ibis_compiler/ibis_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def compile_isin(
return left.isin_join(
right=right,
indicator_col=node.indicator_col.sql,
conditions=(node.left_col.id.sql, node.right_col.id.sql),
conditions=(node.left_col.id.sql, list(node.right_child.ids)[0].sql),
join_nulls=node.joins_nulls,
)

Expand Down
7 changes: 3 additions & 4 deletions bigframes/core/compile/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,12 +700,11 @@ def compile_join(self, node: nodes.JoinNode):
@compile_node.register
def compile_isin(self, node: nodes.InNode):
left = self.compile_node(node.left_child)
right = self.compile_node(node.right_child).unique(node.right_col.id.sql)
right = self.compile_node(node.right_child).unique()
right = right.with_columns(pl.lit(True).alias(node.indicator_col.sql))

left_ex, right_ex = lowering._coerce_comparables(
node.left_col, node.right_col
)
right_col = ex.ResolvedDerefOp.from_field(node.right_child.fields[0])
left_ex, right_ex = lowering._coerce_comparables(node.left_col, right_col)

left_pl_ex = self.expr_compiler.compile_expression(left_ex)
right_pl_ex = self.expr_compiler.compile_expression(right_ex)
Expand Down
7 changes: 5 additions & 2 deletions bigframes/core/compile/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,17 @@ def compile_join(
def compile_isin_join(
node: nodes.InNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
) -> ir.SQLGlotIR:
right_field = node.right_child.fields[0]
conditions = (
typed_expr.TypedExpr(
scalar_compiler.scalar_op_compiler.compile_expression(node.left_col),
node.left_col.output_type,
),
typed_expr.TypedExpr(
scalar_compiler.scalar_op_compiler.compile_expression(node.right_col),
node.right_col.output_type,
scalar_compiler.scalar_op_compiler.compile_expression(
expression.DerefOp(right_field.id)
),
right_field.dtype,
),
)

Expand Down
19 changes: 9 additions & 10 deletions bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,10 @@ class InNode(BigFrameNode, AdditiveNode):
left_child: BigFrameNode
right_child: BigFrameNode
left_col: ex.DerefOp
right_col: ex.DerefOp
indicator_col: identifiers.ColumnId

def _validate(self):
assert not (
set(self.left_child.ids) & set(self.right_child.ids)
), "Join ids collide"
assert len(self.right_child.fields) == 1

@property
def row_preserving(self) -> bool:
Expand Down Expand Up @@ -259,7 +256,11 @@ def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:

@property
def referenced_ids(self) -> COLUMN_SET:
return frozenset({self.left_col.id, self.right_col.id})
return frozenset(
{
self.left_col.id,
}
)

@property
def additive_base(self) -> BigFrameNode:
Expand All @@ -268,12 +269,13 @@ def additive_base(self) -> BigFrameNode:
@property
def joins_nulls(self) -> bool:
left_nullable = self.left_child.field_by_id[self.left_col.id].nullable
right_nullable = self.right_child.field_by_id[self.right_col.id].nullable
# assumption: right side has one column
right_nullable = self.right_child.fields[0].nullable
return left_nullable or right_nullable

@property
def _node_expressions(self):
return (self.left_col, self.right_col)
return (self.left_col,)

def replace_additive_base(self, node: BigFrameNode):
return dataclasses.replace(self, left_child=node)
Expand Down Expand Up @@ -302,9 +304,6 @@ def remap_refs(
left_col=self.left_col.remap_column_refs(
mappings, allow_partial_bindings=True
),
right_col=self.right_col.remap_column_refs(
mappings, allow_partial_bindings=True
),
) # type: ignore


Expand Down
3 changes: 0 additions & 3 deletions bigframes/core/rewrite/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@ def remap_variables(
left_col=new_root.left_col.remap_column_refs(
new_child_mappings[0], allow_partial_bindings=True
),
right_col=new_root.right_col.remap_column_refs(
new_child_mappings[1], allow_partial_bindings=True
),
)
else:
new_root = new_root.remap_refs(downstream_mappings)
Expand Down
31 changes: 1 addition & 30 deletions bigframes/core/rewrite/implicit_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import dataclasses
import itertools
from typing import cast, Optional, Sequence, Set, Tuple
from typing import Optional, Sequence, Set, Tuple

import bigframes.core.expression
import bigframes.core.identifiers
Expand Down Expand Up @@ -152,35 +152,6 @@ def pull_up_selection(
return node, tuple(
bigframes.core.nodes.AliasedRef.identity(field.id) for field in node.fields
)
# InNode needs special handling, as its a binary node, but row identity is from left side only.
# TODO: Merge code with unary op paths
if isinstance(node, bigframes.core.nodes.InNode):
child_node, child_selections = pull_up_selection(
node.left_child, stop=stop, rename_vars=rename_vars
)
mapping = {out: ref.id for ref, out in child_selections}

new_in_node: bigframes.core.nodes.InNode = dataclasses.replace(
node, left_child=child_node
)
new_in_node = new_in_node.remap_refs(mapping)
if rename_vars:
new_in_node = cast(
bigframes.core.nodes.InNode,
new_in_node.remap_vars(
{node.indicator_col: bigframes.core.identifiers.ColumnId.unique()}
),
)
added_selection = tuple(
(
bigframes.core.nodes.AliasedRef(
bigframes.core.expression.DerefOp(new_in_node.indicator_col),
node.indicator_col,
),
)
)
new_selection = child_selections + added_selection
return new_in_node, new_selection

if isinstance(node, bigframes.core.nodes.AdditiveNode):
child_node, child_selections = pull_up_selection(
Expand Down
5 changes: 0 additions & 5 deletions bigframes/core/rewrite/pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,6 @@ def prune_columns(node: nodes.BigFrameNode):
result = node.replace_child(prune_node(node.child, node.consumed_ids))
elif isinstance(node, nodes.AggregateNode):
result = node.replace_child(prune_node(node.child, node.consumed_ids))
elif isinstance(node, nodes.InNode):
result = dataclasses.replace(
node,
right_child=prune_node(node.right_child, frozenset([node.right_col.id])),
)
else:
result = node
return result
Expand Down
3 changes: 0 additions & 3 deletions bigframes/core/rewrite/schema_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ def bind_schema_to_node(
left_col=ex.ResolvedDerefOp.from_field(
node.left_child.field_by_id[node.left_col.id]
),
right_col=ex.ResolvedDerefOp.from_field(
node.right_child.field_by_id[node.right_col.id]
),
)

if isinstance(node, nodes.AggregateNode):
Expand Down
4 changes: 3 additions & 1 deletion tests/system/small/engines/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,10 @@ def test_engines_cross_join(
def test_engines_isin(
scalars_array_value: array_value.ArrayValue, engine, left_key, right_key
):
other = scalars_array_value.select_columns([right_key])
result, _ = scalars_array_value.isin(
scalars_array_value, lcol=left_key, rcol=right_key
other,
lcol=left_key,
)

assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)
10 changes: 5 additions & 5 deletions tests/unit/core/rewrite/test_identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,13 @@ def test_remap_variables_concat_self_stability(leaf):

def test_remap_variables_in_node_converts_dag_to_tree(leaf, leaf_too):
# Create an InNode with the same child twice, should create a tree from a DAG
right = nodes.SelectionNode(
leaf_too, (nodes.AliasedRef.identity(identifiers.ColumnId("col_a")),)
)
node = nodes.InNode(
left_child=leaf,
right_child=leaf_too,
right_child=right,
left_col=ex.DerefOp(identifiers.ColumnId("col_a")),
right_col=ex.DerefOp(identifiers.ColumnId("col_a")),
indicator_col=identifiers.ColumnId("indicator"),
)

Expand All @@ -147,7 +149,5 @@ def test_remap_variables_in_node_converts_dag_to_tree(leaf, leaf_too):
new_node = typing.cast(nodes.InNode, new_node)

left_col_id = new_node.left_col.id.name
right_col_id = new_node.right_col.id.name
new_node.validate_tree()
assert left_col_id.startswith("id_")
assert right_col_id.startswith("id_")
assert left_col_id != right_col_id