From f4a818f5d3c341b512b08c4adc5238f1e6d7a575 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 14 May 2025 17:45:08 +0000 Subject: [PATCH] refactor: DerefOp.ColumnId -> DerefOp.Field --- bigframes/core/array_value.py | 23 +++++++++----- bigframes/core/compile/compiler.py | 4 ++- bigframes/core/compile/sqlglot/compiler.py | 2 +- bigframes/core/expression.py | 8 ++++- bigframes/core/nodes.py | 4 +-- bigframes/core/rewrite/implicit_align.py | 4 +-- bigframes/core/rewrite/legacy_align.py | 7 ++-- bigframes/core/rewrite/order.py | 37 +++++++++++++++------- bigframes/core/rewrite/pruning.py | 12 +++++-- bigframes/core/rewrite/slices.py | 6 ++-- 10 files changed, 72 insertions(+), 35 deletions(-) diff --git a/bigframes/core/array_value.py b/bigframes/core/array_value.py index 60f5315554..69c9a968fb 100644 --- a/bigframes/core/array_value.py +++ b/bigframes/core/array_value.py @@ -284,8 +284,14 @@ def assign(self, source_id: str, destination_id: str) -> ArrayValue: if destination_id in self.column_ids: # Mutate case exprs = [ ( - bigframes.core.nodes.AliasedRef( - ex.deref(source_id if (col_id == destination_id) else col_id), + nodes.AliasedRef( + ex.DerefOp( + self.node.field_by_id[ + ids.ColumnId( + source_id if (col_id == destination_id) else col_id + ) + ] + ), ids.ColumnId(col_id), ) ) @@ -293,14 +299,15 @@ def assign(self, source_id: str, destination_id: str) -> ArrayValue: ] else: # append case self_projection = ( - bigframes.core.nodes.AliasedRef.identity(ids.ColumnId(col_id)) + nodes.AliasedRef.identity(self.node.field_by_id[ids.ColumnId(col_id)]) for col_id in self.column_ids ) exprs = [ *self_projection, ( - bigframes.core.nodes.AliasedRef( - ex.deref(source_id), ids.ColumnId(destination_id) + nodes.AliasedRef( + ex.DerefOp(self.node.field_by_id[ids.ColumnId(source_id)]), + ids.ColumnId(destination_id), ) ), ] @@ -325,7 +332,7 @@ def create_constant( def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue: # This basically just drops and reorders columns - logically a no-op except as a final step selections = ( - bigframes.core.nodes.AliasedRef.identity(ids.ColumnId(col_id)) + nodes.AliasedRef.identity(self.node.field_by_id[ids.ColumnId(col_id)]) for col_id in column_ids ) return ArrayValue( @@ -343,8 +350,8 @@ def rename_columns(self, col_id_overrides: Mapping[str, str]) -> ArrayValue: nodes.SelectionNode( self.node, tuple( - nodes.AliasedRef(ex.DerefOp(old_id), ids.ColumnId(out_id)) - for old_id, out_id in zip(self.node.ids, output_ids) + nodes.AliasedRef(ex.DerefOp(field), ids.ColumnId(out_id)) + for field, out_id in zip(self.node.fields, output_ids) ), ) ) diff --git a/bigframes/core/compile/compiler.py b/bigframes/core/compile/compiler.py index fb5399b7cb..c83068eee1 100644 --- a/bigframes/core/compile/compiler.py +++ b/bigframes/core/compile/compiler.py @@ -40,7 +40,9 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult: - output_names = tuple((expression.DerefOp(id), id.sql) for id in request.node.ids) + output_names = tuple( + (expression.DerefOp(field), field.id.sql) for field in request.node.fields + ) result_node = nodes.ResultNode( request.node, output_cols=output_names, diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index 91d1fa0d85..5ed786e236 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -65,7 +65,7 @@ def compile_raw( def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult: output_names = tuple( - (expression.DerefOp(id), id.sql) for id in request.node.ids + (expression.DerefOp(field), field.id.sql) for field in request.node.fields ) result_node = nodes.ResultNode( request.node, diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py index afd290827d..f9d2281996 100644 --- a/bigframes/core/expression.py +++ b/bigframes/core/expression.py @@ -22,6 +22,7 @@ import pandas as pd +import bigframes.core.bigframe_node as bigframe_node import bigframes.core.identifiers as ids import bigframes.dtypes as dtypes import bigframes.operations @@ -342,7 +343,7 @@ def is_identity(self) -> bool: class DerefOp(Expression): """A variable expression representing an unbound variable.""" - id: ids.ColumnId + field: bigframe_node.Field @property def column_references(self) -> typing.Tuple[ids.ColumnId, ...]: @@ -357,6 +358,11 @@ def nullable(self) -> bool: # Safe default, need to actually bind input schema to determine return True + @property + def id(self) -> ids.ColumnId: + # Safe default, need to actually bind input schema to determine + return self.field.id + def output_type( self, input_types: dict[ids.ColumnId, bigframes.dtypes.Dtype] ) -> dtypes.ExpressionType: diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index 0fbfe7bd37..e9fafa515d 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -1081,8 +1081,8 @@ class AliasedRef(typing.NamedTuple): id: identifiers.ColumnId @classmethod - def identity(cls, id: identifiers.ColumnId) -> AliasedRef: - return cls(ex.DerefOp(id), id) + def identity(cls, field: Field) -> AliasedRef: + return cls(ex.DerefOp(field), field.id) def remap_vars( self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId] diff --git a/bigframes/core/rewrite/implicit_align.py b/bigframes/core/rewrite/implicit_align.py index 1989b1a543..c3804e6ae3 100644 --- a/bigframes/core/rewrite/implicit_align.py +++ b/bigframes/core/rewrite/implicit_align.py @@ -50,7 +50,7 @@ def get_expression_spec( # TODO: While we chain expression fragments from different nodes # we could further normalize with constant folding and other scalar expression rewrites expression: bigframes.core.expression.Expression = ( - bigframes.core.expression.DerefOp(id) + bigframes.core.expression.DerefOp(node.field_by_id[id]) ) curr_node = node while True: @@ -205,7 +205,7 @@ def pull_up_selection( var_renames = {} assert isinstance(new_node, bigframes.core.nodes.AdditiveNode) added_selections = tuple( - bigframes.core.nodes.AliasedRef.identity(field.id).remap_refs(var_renames) + bigframes.core.nodes.AliasedRef.identity(field).remap_refs(var_renames) for field in node.added_fields ) new_selection = child_selections + added_selections diff --git a/bigframes/core/rewrite/legacy_align.py b/bigframes/core/rewrite/legacy_align.py index 573a7026e4..2ad9bbbfc6 100644 --- a/bigframes/core/rewrite/legacy_align.py +++ b/bigframes/core/rewrite/legacy_align.py @@ -52,7 +52,9 @@ def from_node_span( cls, node: nodes.BigFrameNode, target: nodes.BigFrameNode ) -> SquashedSelect: if node == target: - selection = tuple((scalar_exprs.DerefOp(id), id) for id in node.ids) + selection = tuple( + (scalar_exprs.DerefOp(field), field.id) for field in node.fields + ) return cls(node, selection, None, ()) if isinstance(node, nodes.SelectionNode): @@ -229,7 +231,8 @@ def expand(self) -> nodes.BigFrameNode: if self.ordering: root = nodes.OrderByNode(child=root, by=self.ordering) selection = tuple( - bigframes.core.nodes.AliasedRef.identity(id) for _, id in self.columns + nodes.AliasedRef.identity(self.root.field_by_id[id]) + for _, id in self.columns ) return nodes.SelectionNode( child=nodes.ProjectionNode(child=root, assignments=self.columns), diff --git a/bigframes/core/rewrite/order.py b/bigframes/core/rewrite/order.py index 5b5fb10753..6d77107a70 100644 --- a/bigframes/core/rewrite/order.py +++ b/bigframes/core/rewrite/order.py @@ -32,7 +32,9 @@ def defer_order( else order ) if output_hidden_row_keys: - output_names = tuple((expression.DerefOp(id), id.sql) for id in new_child.ids) + output_names = tuple( + (expression.DerefOp(field), field.id.sql) for field in new_child.fields + ) else: output_names = root.output_cols return dataclasses.replace( @@ -102,7 +104,7 @@ def pull_up_order_inner( bigframes.core.ordering.TotalOrdering( ordering_value_columns=tuple(new_by), total_ordering_columns=frozenset( - map(lambda x: bigframes.core.expression.DerefOp(x), ids) + map(lambda x: expression.DerefOp(x), ids) ), ) ) @@ -203,7 +205,7 @@ def pull_up_order_inner( col: identifiers.ColumnId.unique() for col in unselected_order_cols } all_selections = node.input_output_pairs + tuple( - bigframes.core.nodes.AliasedRef(bigframes.core.expression.DerefOp(k), v) + bigframes.core.nodes.AliasedRef(expression.DerefOp(k), v) for k, v in new_selections.items() ) new_select_node = dataclasses.replace( @@ -287,9 +289,7 @@ def pull_order_concat( new_source, ((order_expression.scalar_expression, offsets_id),) ) else: - agg = bigframes.core.expression.NullaryAggregation( - agg_ops.RowNumberOp() - ) + agg = expression.NullaryAggregation(agg_ops.RowNumberOp()) window_spec = bigframes.core.window_spec.unbound( ordering=tuple(order.all_ordering_columns) ) @@ -297,14 +297,24 @@ def pull_order_concat( new_source, agg, window_spec, offsets_id ) new_source = bigframes.core.nodes.ProjectionNode( - new_source, ((bigframes.core.expression.const(i), table_id),) + new_source, ((expression.const(i), table_id),) + ) + offsets_id_alias = bigframes.core.nodes.AliasedRef.identity( + new_source.field_by_id[offsets_id] + ) + table_id_alias = bigframes.core.nodes.AliasedRef.identity( + new_source.field_by_id[table_id] ) selection = tuple( ( - bigframes.core.nodes.AliasedRef.identity(id) - for id in (*source.ids, table_id, offsets_id) + bigframes.core.nodes.AliasedRef.identity(field) + for field in source.fields ) ) + selection = selection + ( + offsets_id_alias, + table_id_alias, + ) new_source = bigframes.core.nodes.SelectionNode(new_source, selection) new_sources.append(new_source) @@ -409,7 +419,10 @@ def remove_order_strict( if result.ids != node.ids: return bigframes.core.nodes.SelectionNode( result, - tuple(bigframes.core.nodes.AliasedRef.identity(id) for id in node.ids), + tuple( + bigframes.core.nodes.AliasedRef.identity(field) + for field in node.fields + ), ) return result @@ -439,8 +452,8 @@ def rename_cols( result_node = bigframes.core.nodes.SelectionNode( node, tuple( - bigframes.core.nodes.AliasedRef.identity(id).remap_vars(mappings) - for id in node.ids + bigframes.core.nodes.AliasedRef.identity(field).remap_vars(mappings) + for field in node.fields ), ) diff --git a/bigframes/core/rewrite/pruning.py b/bigframes/core/rewrite/pruning.py index 1ecfb452ec..24a2b1d75e 100644 --- a/bigframes/core/rewrite/pruning.py +++ b/bigframes/core/rewrite/pruning.py @@ -111,8 +111,10 @@ def prune_selection_child( return selection new_children = [] for concat_node in child.child_nodes: - cc_ids = tuple(concat_node.ids) - sub_selection = tuple(nodes.AliasedRef.identity(cc_ids[i]) for i in indices) + cc_fields = tuple(concat_node.fields) + sub_selection = tuple( + nodes.AliasedRef.identity(cc_fields[i]) for i in indices + ) new_children.append(nodes.SelectionNode(concat_node, sub_selection)) return nodes.ConcatNode( children=tuple(new_children), output_ids=tuple(selection.ids) @@ -151,7 +153,11 @@ def prune_node( else: return nodes.SelectionNode( node, - tuple(nodes.AliasedRef.identity(id) for id in node.ids if id in ids), + tuple( + nodes.AliasedRef.identity(field) + for field in node.fields + if field.id in ids + ), ) diff --git a/bigframes/core/rewrite/slices.py b/bigframes/core/rewrite/slices.py index b8a003e061..cc8d9e5b36 100644 --- a/bigframes/core/rewrite/slices.py +++ b/bigframes/core/rewrite/slices.py @@ -133,9 +133,9 @@ def drop_cols( ) -> nodes.SelectionNode: # adding a whole node that redefines the schema is a lot of overhead, should do something more efficient selections = tuple( - nodes.AliasedRef(scalar_exprs.DerefOp(id), id) - for id in node.ids - if id not in drop_cols + nodes.AliasedRef(scalar_exprs.DerefOp(field), field.id) + for field in node.fields + if field.id not in drop_cols ) return nodes.SelectionNode(node, selections)