From be7d461702213420dc47a607c214ea4faf2a0637 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Tue, 4 Feb 2025 21:22:27 +0000 Subject: [PATCH 1/6] perf: Prune unused operations from sql --- bigframes/core/__init__.py | 24 +- bigframes/core/compile/compiler.py | 16 +- bigframes/core/nodes.py | 423 ++++++++++++----------- bigframes/core/rewrite/__init__.py | 2 + bigframes/core/rewrite/implicit_align.py | 41 +-- bigframes/core/rewrite/legacy_align.py | 6 +- bigframes/core/rewrite/order.py | 16 +- bigframes/core/rewrite/pruning.py | 207 +++++++++++ bigframes/core/rewrite/slices.py | 4 +- 9 files changed, 491 insertions(+), 248 deletions(-) create mode 100644 bigframes/core/rewrite/pruning.py diff --git a/bigframes/core/__init__.py b/bigframes/core/__init__.py index 5f64bf68dd..dc9b8e3b9b 100644 --- a/bigframes/core/__init__.py +++ b/bigframes/core/__init__.py @@ -304,18 +304,25 @@ def assign(self, source_id: str, destination_id: str) -> ArrayValue: if destination_id in self.column_ids: # Mutate case exprs = [ ( - ex.deref(source_id if (col_id == destination_id) else col_id), - ids.ColumnId(col_id), + bigframes.core.nodes.AliasedRef( + ex.deref(source_id if (col_id == destination_id) else col_id), + ids.ColumnId(col_id), + ) ) for col_id in self.column_ids ] else: # append case self_projection = ( - (ex.deref(col_id), ids.ColumnId(col_id)) for col_id in self.column_ids + bigframes.core.nodes.AliasedRef.identity(ids.ColumnId(col_id)) + for col_id in self.column_ids ) exprs = [ *self_projection, - (ex.deref(source_id), ids.ColumnId(destination_id)), + ( + bigframes.core.nodes.AliasedRef( + ex.deref(source_id), ids.ColumnId(destination_id) + ) + ), ] return ArrayValue( nodes.SelectionNode( @@ -337,7 +344,10 @@ def create_constant( def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue: # This basically just drops and reorders columns - logically a no-op except as a final step - selections = ((ex.deref(col_id), ids.ColumnId(col_id)) for col_id in column_ids) + selections = ( + bigframes.core.nodes.AliasedRef.identity(ids.ColumnId(col_id)) + for col_id in column_ids + ) return ArrayValue( nodes.SelectionNode( child=self.node, @@ -488,7 +498,9 @@ def prepare_join_names( nodes.SelectionNode( other.node, tuple( - (ex.deref(old_id), ids.ColumnId(new_id)) + bigframes.core.nodes.AliasedRef( + ex.deref(old_id), ids.ColumnId(new_id) + ) for old_id, new_id in r_mapping.items() ), ), diff --git a/bigframes/core/compile/compiler.py b/bigframes/core/compile/compiler.py index dca204401e..48dd7f2c5f 100644 --- a/bigframes/core/compile/compiler.py +++ b/bigframes/core/compile/compiler.py @@ -65,6 +65,7 @@ def compile_sql( node, ordering = rewrites.pull_up_order( node, order_root=True, ordered_joins=self.strict ) + node = rewrites.column_pruning(node) ir = self.compile_node(node) return ir.to_sql( order_by=ordering.all_ordering_columns, @@ -76,6 +77,7 @@ def compile_sql( node, _ = rewrites.pull_up_order( node, order_root=False, ordered_joins=self.strict ) + node = rewrites.column_pruning(node) ir = self.compile_node(node) return ir.to_sql(selections=output_ids) @@ -86,6 +88,7 @@ def compile_peek_sql(self, node: nodes.BigFrameNode, n_rows: int) -> str: node, _ = rewrites.pull_up_order( node, order_root=False, ordered_joins=self.strict ) + node = rewrites.column_pruning(node) return self.compile_node(node).to_sql(limit=n_rows, selections=ids) def compile_raw( @@ -97,6 +100,7 @@ def compile_raw( node = nodes.bottom_up(node, rewrites.rewrite_slice) node = nodes.top_down(node, rewrites.rewrite_timedelta_ops) node, ordering = rewrites.pull_up_order(node, ordered_joins=self.strict) + node = rewrites.column_pruning(node) ir = self.compile_node(node) sql = ir.to_sql() return sql, node.schema.to_bigquery(), ordering @@ -192,10 +196,12 @@ def compile_readtable(self, node: nodes.ReadTableNode): return self.compile_read_table_unordered(node.source, node.scan_list) def read_table_as_unordered_ibis( - self, source: nodes.BigqueryDataSource + self, + source: nodes.BigqueryDataSource, + scan_cols: typing.Sequence[str], ) -> ibis_types.Table: full_table_name = f"{source.table.project_id}.{source.table.dataset_id}.{source.table.table_id}" - used_columns = tuple(col.name for col in source.table.physical_schema) + used_columns = tuple(scan_cols) # Physical schema might include unused columns, unsupported datatypes like JSON physical_schema = ibis_bigquery.BigQuerySchema.to_ibis( list(i for i in source.table.physical_schema if i.name in used_columns) @@ -216,7 +222,9 @@ def read_table_as_unordered_ibis( def compile_read_table_unordered( self, source: nodes.BigqueryDataSource, scan: nodes.ScanList ): - ibis_table = self.read_table_as_unordered_ibis(source) + ibis_table = self.read_table_as_unordered_ibis( + source, scan_cols=[col.source_id for col in scan.items] + ) return compiled.UnorderedIR( ibis_table, tuple( @@ -291,7 +299,7 @@ def set_output_names( return nodes.SelectionNode( node, tuple( - (ex.DerefOp(old_id), ids.ColumnId(out_id)) + bigframes.core.nodes.AliasedRef(ex.DerefOp(old_id), ids.ColumnId(out_id)) for old_id, out_id in zip(node.ids, output_ids) ), ) diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index 085d52daa6..c64e18f8ed 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -20,7 +20,7 @@ import functools import itertools import typing -from typing import Callable, cast, Iterable, Mapping, Optional, Sequence, Tuple +from typing import Callable, cast, Iterable, Mapping, Optional, Sequence, Tuple, TypeVar import google.cloud.bigquery as bq @@ -44,6 +44,8 @@ COLUMN_SET = frozenset[bfet_ids.ColumnId] +Self = TypeVar("Self") + @dataclasses.dataclass(frozen=True) class Field: @@ -87,10 +89,17 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]: def row_count(self) -> typing.Optional[int]: return None + @abc.abstractmethod + def remap_vars( + self: Self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> Self: + """Remap defined (in this node only) variables.""" + ... + @abc.abstractmethod def remap_refs( - self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + self: Self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> Self: """Remap variable references""" ... @@ -100,6 +109,10 @@ def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: """The variables defined in this node (as opposed to by child nodes).""" ... + @property + def referenced_ids(self) -> COLUMN_SET: + return frozenset() + @functools.cached_property def session(self): sessions = [] @@ -248,18 +261,11 @@ def planning_complexity(self) -> int: @abc.abstractmethod def transform_children( - self, t: Callable[[BigFrameNode], BigFrameNode] - ) -> BigFrameNode: + self: Self, t: Callable[[BigFrameNode], BigFrameNode] + ) -> Self: """Apply a function to each child node.""" ... - @abc.abstractmethod - def remap_vars( - self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: - """Remap defined (in this node only) variables.""" - ... - @property def defines_namespace(self) -> bool: """ @@ -269,16 +275,6 @@ def defines_namespace(self) -> bool: """ return False - @functools.cached_property - def defined_variables(self) -> set[str]: - """Full set of variables defined in the namespace, even if not selected.""" - self_defined_variables = set(self.schema.names) - if self.defines_namespace: - return self_defined_variables - return self_defined_variables.union( - *(child.defined_variables for child in self.child_nodes) - ) - def get_type(self, id: bfet_ids.ColumnId) -> bigframes.dtypes.Dtype: return self._dtype_lookup[id] @@ -286,9 +282,6 @@ def get_type(self, id: bfet_ids.ColumnId) -> bigframes.dtypes.Dtype: def _dtype_lookup(self): return {field.id: field.dtype for field in self.fields} - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - return self.transform_children(lambda x: x.prune(used_cols)) - class AdditiveNode: """Definition of additive - if you drop added_fields, you end up with the descendent. @@ -336,7 +329,7 @@ def explicitly_ordered(self) -> bool: def transform_children( self, t: Callable[[BigFrameNode], BigFrameNode] - ) -> BigFrameNode: + ) -> UnaryNode: transformed = dataclasses.replace(self, child=t(self.child)) if self == transformed: # reusing existing object speeds up eq, and saves a small amount of memory @@ -406,12 +399,18 @@ def row_count(self) -> typing.Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return () + @property + def referenced_ids(self) -> COLUMN_SET: + return frozenset() + def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> SliceNode: return self - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> SliceNode: return self @@ -483,6 +482,10 @@ def row_count(self) -> Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return (self.indicator_col,) + @property + def referenced_ids(self) -> COLUMN_SET: + return frozenset({self.left_col.id, self.right_col.id}) + @property def additive_base(self) -> BigFrameNode: return self.left_child @@ -490,9 +493,7 @@ def additive_base(self) -> BigFrameNode: def replace_additive_base(self, node: BigFrameNode): return dataclasses.replace(self, left_child=node) - def transform_children( - self, t: Callable[[BigFrameNode], BigFrameNode] - ) -> BigFrameNode: + def transform_children(self, t: Callable[[BigFrameNode], BigFrameNode]) -> InNode: transformed = dataclasses.replace( self, left_child=t(self.left_child), right_child=t(self.right_child) ) @@ -501,17 +502,16 @@ def transform_children( return self return transformed - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - return self - def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> InNode: return dataclasses.replace( self, indicator_col=mappings.get(self.indicator_col, self.indicator_col) ) - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> InNode: return dataclasses.replace(self, left_col=self.left_col.remap_column_refs(mappings, allow_partial_bindings=True), right_col=self.right_col.remap_column_refs(mappings, allow_partial_bindings=True)) # type: ignore @@ -574,9 +574,20 @@ def row_count(self) -> Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return () - def transform_children( - self, t: Callable[[BigFrameNode], BigFrameNode] - ) -> BigFrameNode: + @property + def referenced_ids(self) -> COLUMN_SET: + return frozenset( + itertools.chain.from_iterable( + (*l_cond.column_references, *r_cond.column_references) + for l_cond, r_cond in self.conditions + ) + ) + + @property + def consumed_ids(self) -> COLUMN_SET: + return frozenset(*self.ids, *self.referenced_ids) + + def transform_children(self, t: Callable[[BigFrameNode], BigFrameNode]) -> JoinNode: transformed = dataclasses.replace( self, left_child=t(self.left_child), right_child=t(self.right_child) ) @@ -585,21 +596,14 @@ def transform_children( return self return transformed - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - # If this is a cross join, make sure to select at least one column from each side - condition_cols = used_cols.union( - map(lambda x: x.id, itertools.chain.from_iterable(self.conditions)) - ) - return self.transform_children( - lambda x: x.prune(frozenset([*condition_cols, *used_cols])) - ) - def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> JoinNode: return self - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> JoinNode: new_conds = tuple( ( l_cond.remap_column_refs(mappings, allow_partial_bindings=True), @@ -665,7 +669,7 @@ def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: def transform_children( self, t: Callable[[BigFrameNode], BigFrameNode] - ) -> BigFrameNode: + ) -> ConcatNode: transformed = dataclasses.replace( self, children=tuple(t(child) for child in self.children) ) @@ -674,17 +678,15 @@ def transform_children( return self return transformed - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - # TODO: Make concat prunable, probably by redefining - return self - def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> ConcatNode: new_ids = tuple(mappings.get(id, id) for id in self.output_ids) return dataclasses.replace(self, output_ids=new_ids) - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> ConcatNode: return self @@ -735,25 +737,23 @@ def defines_namespace(self) -> bool: def transform_children( self, t: Callable[[BigFrameNode], BigFrameNode] - ) -> BigFrameNode: + ) -> FromRangeNode: transformed = dataclasses.replace(self, start=t(self.start), end=t(self.end)) if self == transformed: # reusing existing object speeds up eq, and saves a small amount of memory return self return transformed - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - # TODO: Make FromRangeNode prunable (or convert to other node types) - return self - def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> FromRangeNode: return dataclasses.replace( self, output_id=mappings.get(self.output_id, self.output_id) ) - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> FromRangeNode: return self @@ -774,9 +774,7 @@ def fast_offsets(self) -> bool: def fast_ordered_limit(self) -> bool: return False - def transform_children( - self, t: Callable[[BigFrameNode], BigFrameNode] - ) -> BigFrameNode: + def transform_children(self, t: Callable[[BigFrameNode], BigFrameNode]) -> LeafNode: return self @@ -785,6 +783,9 @@ class ScanItem(typing.NamedTuple): dtype: bigframes.dtypes.Dtype # Might be multiple logical types for a given physical source type source_id: str # Flexible enough for both local data and bq data + def with_id(self, id: bfet_ids.ColumnId) -> ScanItem: + return ScanItem(id, self.dtype, self.source_id) + @dataclasses.dataclass(frozen=True) class ScanList: @@ -841,25 +842,9 @@ def row_count(self) -> typing.Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return tuple(item.id for item in self.fields) - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - # Don't preoduce empty scan list no matter what, will result in broken sql syntax - # TODO: Handle more elegantly - new_scan_list = ScanList( - tuple(item for item in self.scan_list.items if item.id in used_cols) - or (self.scan_list.items[0],) - ) - return ReadLocalNode( - self.feather_bytes, - self.data_schema, - self.n_rows, - new_scan_list, - self.offsets_col, - self.session, - ) - def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> ReadLocalNode: new_scan_list = ScanList( tuple( ScanItem(mappings.get(item.id, item.id), item.dtype, item.source_id) @@ -875,7 +860,9 @@ def remap_vars( self, scan_list=new_scan_list, offsets_col=new_offsets_col ) - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> ReadLocalNode: return self @@ -1003,16 +990,9 @@ def row_count(self) -> typing.Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return tuple(item.id for item in self.scan_list.items) - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - new_scan_list = ScanList( - tuple(item for item in self.scan_list.items if item.id in used_cols) - or (self.scan_list.items[0],) - ) - return dataclasses.replace(self, scan_list=new_scan_list) - def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> ReadTableNode: new_scan_list = ScanList( tuple( ScanItem(mappings.get(item.id, item.id), item.dtype, item.source_id) @@ -1021,7 +1001,9 @@ def remap_vars( ) return dataclasses.replace(self, scan_list=new_scan_list) - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> ReadTableNode: return self def with_order_cols(self): @@ -1089,6 +1071,10 @@ def row_count(self) -> Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return (self.col_id,) + @property + def referenced_ids(self) -> COLUMN_SET: + return frozenset() + @property def added_fields(self) -> Tuple[Field, ...]: return (Field(self.col_id, bigframes.dtypes.INT_DTYPE),) @@ -1097,22 +1083,17 @@ def added_fields(self) -> Tuple[Field, ...]: def additive_base(self) -> BigFrameNode: return self.child - def replace_additive_base(self, node: BigFrameNode): + def replace_additive_base(self, node: BigFrameNode) -> PromoteOffsetsNode: return dataclasses.replace(self, child=node) - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - if self.col_id not in used_cols: - return self.child.prune(used_cols) - else: - new_used = used_cols.difference([self.col_id]) - return self.transform_children(lambda x: x.prune(new_used)) - def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> PromoteOffsetsNode: return dataclasses.replace(self, col_id=mappings.get(self.col_id, self.col_id)) - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> PromoteOffsetsNode: return self @@ -1136,17 +1117,22 @@ def row_count(self) -> Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return () - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - consumed_ids = used_cols.union(self.predicate.column_references) - pruned_child = self.child.prune(consumed_ids) - return FilterNode(pruned_child, self.predicate) + @property + def consumed_ids(self) -> COLUMN_SET: + return frozenset(self.ids) | self.referenced_ids + + @property + def referenced_ids(self) -> COLUMN_SET: + return frozenset(self.predicate.column_references) def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> FilterNode: return self - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> FilterNode: return dataclasses.replace( self, predicate=self.predicate.remap_column_refs( @@ -1183,20 +1169,24 @@ def row_count(self) -> Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return () - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - ordering_cols = itertools.chain.from_iterable( - map(lambda x: x.referenced_columns, self.by) + @property + def consumed_ids(self) -> COLUMN_SET: + return frozenset(self.ids) | self.referenced_ids + + @property + def referenced_ids(self) -> COLUMN_SET: + return frozenset( + itertools.chain.from_iterable(map(lambda x: x.referenced_columns, self.by)) ) - consumed_ids = used_cols.union(ordering_cols) - pruned_child = self.child.prune(consumed_ids) - return OrderByNode(pruned_child, self.by) def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> OrderByNode: return self - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> OrderByNode: all_refs = set( itertools.chain.from_iterable(map(lambda x: x.referenced_columns, self.by)) ) @@ -1233,20 +1223,43 @@ def row_count(self) -> Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return () + @property + def referenced_ids(self) -> COLUMN_SET: + return frozenset() + def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> ReversedNode: return self - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> ReversedNode: return self +class AliasedRef(typing.NamedTuple): + ref: ex.DerefOp + id: bfet_ids.ColumnId + + @classmethod + def identity(cls, id: bfet_ids.ColumnId) -> AliasedRef: + return cls(ex.DerefOp(id), id) + + def remap_vars( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> AliasedRef: + return AliasedRef(self.ref, mappings.get(self.ref.id, self.ref.id)) + + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> AliasedRef: + return AliasedRef(ex.DerefOp(mappings.get(self.ref.id, self.ref.id)), self.id) + + @dataclasses.dataclass(frozen=True, eq=False) class SelectionNode(UnaryNode): - input_output_pairs: typing.Tuple[ - typing.Tuple[ex.DerefOp, bigframes.core.identifiers.ColumnId], ... - ] + input_output_pairs: Tuple[AliasedRef, ...] def _validate(self): for ref, _ in self.input_output_pairs: @@ -1280,33 +1293,26 @@ def row_count(self) -> Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return tuple(id for _, id in self.input_output_pairs) - def get_id_mapping(self) -> dict[bfet_ids.ColumnId, bfet_ids.ColumnId]: - return {ref.id: out_id for ref, out_id in self.input_output_pairs} - - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - pruned_selections = ( - tuple( - select for select in self.input_output_pairs if select[1] in used_cols - ) - or self.input_output_pairs[:1] - ) - consumed_ids = frozenset(i[0].id for i in pruned_selections) + @property + def consumed_ids(self) -> COLUMN_SET: + return frozenset(ref.id for ref, id in self.input_output_pairs) - pruned_child = self.child.prune(consumed_ids) - return SelectionNode(pruned_child, pruned_selections) + def get_id_mapping(self) -> dict[bfet_ids.ColumnId, bfet_ids.ColumnId]: + return {ref.id: id for ref, id in self.input_output_pairs} def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: - new_pairs = tuple( - (ref, mappings.get(id, id)) for ref, id in self.input_output_pairs + ) -> SelectionNode: + new_fields = tuple( + item.remap_vars(mappings) for item in self.input_output_pairs ) - return dataclasses.replace(self, input_output_pairs=new_pairs) + return dataclasses.replace(self, input_output_pairs=new_fields) - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> SelectionNode: new_fields = tuple( - (ex.remap_column_refs(mappings, allow_partial_bindings=True), id) - for ex, id in self.input_output_pairs + item.remap_refs(mappings) for item in self.input_output_pairs ) return dataclasses.replace(self, input_output_pairs=new_fields) # type: ignore @@ -1353,30 +1359,38 @@ def row_count(self) -> Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return tuple(id for _, id in self.assignments) + @property + def consumed_ids(self) -> COLUMN_SET: + return frozenset( + itertools.chain.from_iterable( + i[0].column_references for i in self.assignments + ) + ) + + @property + def referenced_ids(self) -> COLUMN_SET: + return frozenset( + itertools.chain.from_iterable( + ex.column_references for ex, id in self.assignments + ) + ) + @property def additive_base(self) -> BigFrameNode: return self.child - def replace_additive_base(self, node: BigFrameNode): + def replace_additive_base(self, node: BigFrameNode) -> ProjectionNode: return dataclasses.replace(self, child=node) - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - pruned_assignments = tuple(i for i in self.assignments if i[1] in used_cols) - if len(pruned_assignments) == 0: - return self.child.prune(used_cols) - consumed_ids = itertools.chain.from_iterable( - i[0].column_references for i in pruned_assignments - ) - pruned_child = self.child.prune(used_cols.union(consumed_ids)) - return ProjectionNode(pruned_child, pruned_assignments) - def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> ProjectionNode: new_fields = tuple((ex, mappings.get(id, id)) for ex, id in self.assignments) return dataclasses.replace(self, assignments=new_fields) - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> ProjectionNode: new_fields = tuple( (ex.remap_column_refs(mappings, allow_partial_bindings=True), id) for ex, id in self.assignments @@ -1418,16 +1432,18 @@ def row_count(self) -> Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return (self.col_id,) + @property + def consumed_ids(self) -> COLUMN_SET: + return frozenset() + def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> RowCountNode: return dataclasses.replace(self, col_id=mappings.get(self.col_id, self.col_id)) - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): - return self - - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - # TODO: Handle row count pruning + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> RowCountNode: return self @@ -1487,33 +1503,31 @@ def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return tuple(id for _, id in self.aggregations) @property - def has_ordered_ops(self) -> bool: - return not all( - aggregate.op.order_independent for aggregate, _ in self.aggregations - ) - - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: + def consumed_ids(self) -> COLUMN_SET: by_ids = (ref.id for ref in self.by_column_ids) - pruned_aggs = ( - tuple(agg for agg in self.aggregations if agg[1] in used_cols) - or self.aggregations[:1] - ) agg_inputs = itertools.chain.from_iterable( - agg.column_references for agg, _ in pruned_aggs + agg.column_references for agg, _ in self.aggregations ) - consumed_ids = frozenset(itertools.chain(by_ids, agg_inputs)) - pruned_child = self.child.prune(consumed_ids) - return AggregateNode( - pruned_child, pruned_aggs, self.by_column_ids, dropna=self.dropna + order_ids = itertools.chain.from_iterable( + part.scalar_expression.column_references for part in self.order_by + ) + return frozenset(itertools.chain(by_ids, agg_inputs, order_ids)) + + @property + def has_ordered_ops(self) -> bool: + return not all( + aggregate.op.order_independent for aggregate, _ in self.aggregations ) def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> AggregateNode: new_aggs = tuple((agg, mappings.get(id, id)) for agg, id in self.aggregations) return dataclasses.replace(self, aggregations=new_aggs) - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> AggregateNode: new_aggs = tuple( (agg.remap_column_refs(mappings, allow_partial_bindings=True), id) for agg, id in self.aggregations @@ -1578,6 +1592,20 @@ def added_field(self) -> Field: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return (self.output_name,) + @property + def consumed_ids(self) -> COLUMN_SET: + return frozenset( + set(self.ids).difference([self.output_name]).union(self.referenced_ids) + ) + + @property + def referenced_ids(self) -> COLUMN_SET: + return ( + frozenset() + .union(self.expression.column_references) + .union(self.window_spec.all_referenced_columns) + ) + @property def inherits_order(self) -> bool: # does the op both use ordering at all? and if so, can it inherit order? @@ -1590,27 +1618,19 @@ def inherits_order(self) -> bool: def additive_base(self) -> BigFrameNode: return self.child - def replace_additive_base(self, node: BigFrameNode): + def replace_additive_base(self, node: BigFrameNode) -> WindowOpNode: return dataclasses.replace(self, child=node) - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - if self.output_name not in used_cols: - return self.child.prune(used_cols) - consumed_ids = ( - used_cols.difference([self.output_name]) - .union(self.expression.column_references) - .union(self.window_spec.all_referenced_columns) - ) - return self.transform_children(lambda x: x.prune(consumed_ids)) - def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> WindowOpNode: return dataclasses.replace( self, output_name=mappings.get(self.output_name, self.output_name) ) - def remap_refs(self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId]): + def remap_refs( + self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] + ) -> WindowOpNode: return dataclasses.replace( self, expression=self.expression.remap_column_refs( @@ -1646,14 +1666,18 @@ def row_count(self) -> Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return () + @property + def referenced_ids(self) -> COLUMN_SET: + return frozenset() + def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> RandomSampleNode: return self def remap_refs( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> RandomSampleNode: return self @@ -1703,21 +1727,20 @@ def row_count(self) -> Optional[int]: def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]: return (self.offsets_col,) if (self.offsets_col is not None) else () - def prune(self, used_cols: COLUMN_SET) -> BigFrameNode: - # Cannot prune explode op - consumed_ids = used_cols.union(ref.id for ref in self.column_ids) - return self.transform_children(lambda x: x.prune(consumed_ids)) + @property + def referenced_ids(self) -> COLUMN_SET: + return frozenset(ref.id for ref in self.column_ids) def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> ExplodeNode: if (self.offsets_col is not None) and self.offsets_col in mappings: return dataclasses.replace(self, offsets_col=mappings[self.offsets_col]) return self def remap_refs( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] - ) -> BigFrameNode: + ) -> ExplodeNode: new_ids = tuple(id.remap_column_refs(mappings) for id in self.column_ids) return dataclasses.replace(self, column_ids=new_ids) # type: ignore diff --git a/bigframes/core/rewrite/__init__.py b/bigframes/core/rewrite/__init__.py index f93186bf36..bf93fa51b6 100644 --- a/bigframes/core/rewrite/__init__.py +++ b/bigframes/core/rewrite/__init__.py @@ -17,6 +17,7 @@ from bigframes.core.rewrite.legacy_align import legacy_join_as_projection from bigframes.core.rewrite.operators import rewrite_timedelta_ops from bigframes.core.rewrite.order import pull_up_order +from bigframes.core.rewrite.pruning import column_pruning from bigframes.core.rewrite.slices import pullup_limit_from_slice, rewrite_slice __all__ = [ @@ -27,4 +28,5 @@ "pullup_limit_from_slice", "remap_variables", "pull_up_order", + "column_pruning", ] diff --git a/bigframes/core/rewrite/implicit_align.py b/bigframes/core/rewrite/implicit_align.py index 1b864fb919..7e2d086c82 100644 --- a/bigframes/core/rewrite/implicit_align.py +++ b/bigframes/core/rewrite/implicit_align.py @@ -113,7 +113,7 @@ def try_row_join( r_node, r_selection = pull_up_selection( r_node, stop=divergent_node, rename_vars=True ) # Rename only right vars to avoid collisions with left vars - combined_selection = (*l_selection, *r_selection) + combined_selection = l_selection + r_selection def _linearize_trees( base_tree: bigframes.core.nodes.BigFrameNode, @@ -139,10 +139,7 @@ def pull_up_selection( rename_vars: bool = False, ) -> Tuple[ bigframes.core.nodes.BigFrameNode, - Tuple[ - Tuple[bigframes.core.expression.DerefOp, bigframes.core.identifiers.ColumnId], - ..., - ], + Tuple[bigframes.core.nodes.AliasedRef, ...], ]: """Remove all selection nodes above the base node. Returns stripped tree. @@ -157,8 +154,7 @@ def pull_up_selection( """ if node == stop: # base case return node, tuple( - (bigframes.core.expression.DerefOp(field.id), field.id) - for field in node.fields + bigframes.core.nodes.AliasedRef.identity(field.id) for field in node.fields ) # InNode needs special handling, as its a binary node, but row identity is from left side only. # TODO: Merge code with unary op paths @@ -179,18 +175,21 @@ def pull_up_selection( {node.indicator_col: bigframes.core.identifiers.ColumnId.unique()} ), ) - added_selection = ( - bigframes.core.expression.DerefOp(new_in_node.indicator_col), - node.indicator_col, + added_selection = tuple( + ( + bigframes.core.nodes.AliasedRef( + bigframes.core.expression.DerefOp(new_in_node.indicator_col), + node.indicator_col, + ), + ) ) - new_selection = (*child_selections, added_selection) + new_selection = child_selections + added_selection return new_in_node, new_selection if isinstance(node, bigframes.core.nodes.AdditiveNode): child_node, child_selections = pull_up_selection( node.additive_base, stop, rename_vars=rename_vars ) - mapping = {out: ref.id for ref, out in child_selections} new_node: bigframes.core.nodes.BigFrameNode = node.replace_additive_base( child_node ) @@ -204,28 +203,20 @@ def pull_up_selection( else: var_renames = {} assert isinstance(new_node, bigframes.core.nodes.AdditiveNode) - added_selections = ( - ( - bigframes.core.expression.DerefOp(var_renames.get(field.id, field.id)), - field.id, - ) + added_selections = tuple( + bigframes.core.nodes.AliasedRef.identity(field.id).remap_vars(var_renames) for field in node.added_fields ) - new_selection = (*child_selections, *added_selections) + new_selection = child_selections + added_selections return new_node, new_selection elif isinstance(node, bigframes.core.nodes.SelectionNode): child_node, child_selections = pull_up_selection( node.child, stop, rename_vars=rename_vars ) mapping = {out: ref.id for ref, out in child_selections} - new_selection = tuple( - ( - bigframes.core.expression.DerefOp(mapping[ref.id]), - out, - ) - for ref, out in node.input_output_pairs + return child_node, tuple( + ref.remap_refs(mapping) for ref in node.input_output_pairs ) - return child_node, new_selection raise ValueError(f"Couldn't pull up select from node: {node}") diff --git a/bigframes/core/rewrite/legacy_align.py b/bigframes/core/rewrite/legacy_align.py index 05641130fb..573a7026e4 100644 --- a/bigframes/core/rewrite/legacy_align.py +++ b/bigframes/core/rewrite/legacy_align.py @@ -57,7 +57,7 @@ def from_node_span( if isinstance(node, nodes.SelectionNode): return cls.from_node_span(node.child, target).select( - node.input_output_pairs + tuple(node.input_output_pairs) ) elif isinstance(node, nodes.ProjectionNode): return cls.from_node_span(node.child, target).project(node.assignments) @@ -228,7 +228,9 @@ def expand(self) -> nodes.BigFrameNode: root = nodes.FilterNode(child=root, predicate=self.predicate) if self.ordering: root = nodes.OrderByNode(child=root, by=self.ordering) - selection = tuple((scalar_exprs.DerefOp(id), id) for _, id in self.columns) + selection = tuple( + bigframes.core.nodes.AliasedRef.identity(id) for _, id in self.columns + ) return nodes.SelectionNode( child=nodes.ProjectionNode(child=root, assignments=self.columns), input_output_pairs=selection, diff --git a/bigframes/core/rewrite/order.py b/bigframes/core/rewrite/order.py index 3f8c409b76..18e5004e1d 100644 --- a/bigframes/core/rewrite/order.py +++ b/bigframes/core/rewrite/order.py @@ -180,14 +180,10 @@ def pull_up_order_inner( col: bigframes.core.ids.ColumnId.unique() for col in unselected_order_cols } - all_selections = ( - *node.input_output_pairs, - *( - (bigframes.core.expression.DerefOp(k), v) - for k, v in new_selections.items() - ), + all_selections = node.input_output_pairs + tuple( + bigframes.core.nodes.AliasedRef(bigframes.core.expression.DerefOp(k), v) + for k, v in new_selections.items() ) - new_select_node = dataclasses.replace( node, child=child_result, input_output_pairs=all_selections ) @@ -288,7 +284,7 @@ def pull_order_concat( ) selection = tuple( ( - (bigframes.core.expression.DerefOp(id), id) + bigframes.core.nodes.AliasedRef.identity(id) for id in (*source.ids, table_id, offsets_id) ) ) @@ -396,7 +392,7 @@ def remove_order_strict( if result.ids != node.ids: return bigframes.core.nodes.SelectionNode( result, - tuple((bigframes.core.expression.DerefOp(id), id) for id in node.ids), + tuple(bigframes.core.nodes.AliasedRef.identity(id) for id in node.ids), ) return result @@ -428,7 +424,7 @@ def rename_cols( result_node = bigframes.core.nodes.SelectionNode( node, tuple( - (bigframes.core.expression.DerefOp(id), mappings.get(id, id)) + bigframes.core.nodes.AliasedRef.identity(id).remap_vars(mappings) for id in node.ids ), ) diff --git a/bigframes/core/rewrite/pruning.py b/bigframes/core/rewrite/pruning.py new file mode 100644 index 0000000000..3063c95afb --- /dev/null +++ b/bigframes/core/rewrite/pruning.py @@ -0,0 +1,207 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import dataclasses +import functools +from typing import AbstractSet, Iterable, TypeVar + +import bigframes.core.identifiers +import bigframes.core.nodes + + +def column_pruning( + root: bigframes.core.nodes.BigFrameNode, +) -> bigframes.core.nodes.BigFrameNode: + return bigframes.core.nodes.top_down(root, prune_columns) + + +def to_fixed(max_iterations: int = 100): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + previous_result = None + current_result = func(*args, **kwargs) + attempts = 1 + + while attempts < max_iterations: + if current_result == previous_result: + return current_result + previous_result = current_result + current_result = func(current_result) + attempts += 1 + + return current_result + + return wrapper + + return decorator + + +@to_fixed(max_iterations=100) +def prune_columns(node: bigframes.core.nodes.BigFrameNode): + if isinstance(node, bigframes.core.nodes.SelectionNode): + result = prune_selection_child(node) + elif isinstance(node, bigframes.core.nodes.AggregateNode): + result = node.replace_child(prune_node(node.child, node.consumed_ids)) + elif isinstance(node, bigframes.core.nodes.InNode): + result = dataclasses.replace( + node, + right_child=prune_node(node.right_child, frozenset([node.right_col.id])), + ) + else: + result = node + return result + + +def prune_selection_child( + selection: bigframes.core.nodes.SelectionNode, +) -> bigframes.core.nodes.BigFrameNode: + child = selection.child + + # Important to check this first + if list(selection.ids) == list(child.ids): + return child + + if isinstance(child, bigframes.core.nodes.SelectionNode): + return selection.remap_refs( + {id: ref.id for ref, id in child.input_output_pairs} + ).replace_child(child.child) + elif isinstance(child, bigframes.core.nodes.AdditiveNode): + if not set(field.id for field in child.added_fields) & selection.consumed_ids: + return selection.replace_child(child.additive_base) + return selection.replace_child( + child.replace_additive_base( + prune_node( + child.additive_base, selection.consumed_ids | child.referenced_ids + ) + ) + ) + elif isinstance(child, bigframes.core.nodes.ConcatNode): + indices = [ + list(child.ids).index(ref.id) for ref, _ in selection.input_output_pairs + ] + new_children = [] + for concat_node in child.child_nodes: + cc_ids = tuple(concat_node.ids) + sub_selection = tuple( + bigframes.core.nodes.AliasedRef.identity(cc_ids[i]) for i in indices + ) + new_children.append( + bigframes.core.nodes.SelectionNode(concat_node, sub_selection) + ) + return bigframes.core.nodes.ConcatNode( + children=tuple(new_children), output_ids=tuple(selection.ids) + ) + # Nodes that pass through input columns + elif isinstance( + child, + ( + bigframes.core.nodes.RandomSampleNode, + bigframes.core.nodes.ReversedNode, + bigframes.core.nodes.OrderByNode, + bigframes.core.nodes.FilterNode, + bigframes.core.nodes.SliceNode, + bigframes.core.nodes.JoinNode, + bigframes.core.nodes.ExplodeNode, + ), + ): + ids = selection.consumed_ids | child.referenced_ids + return selection.replace_child( + child.transform_children(lambda x: prune_node(x, ids)) + ) + elif isinstance(child, bigframes.core.nodes.AggregateNode): + return selection.replace_child(prune_aggregate(child, selection.consumed_ids)) + elif isinstance(child, bigframes.core.nodes.LeafNode): + return selection.replace_child(prune_leaf(child, selection.consumed_ids)) + return selection + + +def prune_node( + node: bigframes.core.nodes.BigFrameNode, + ids: AbstractSet[bigframes.core.ids.ColumnId], +): + # This clause is important, ensures idempotency, so can reach fixed point + if not (set(node.ids) - ids): + return node + else: + return bigframes.core.nodes.SelectionNode( + node, + tuple( + bigframes.core.nodes.AliasedRef.identity(id) + for id in node.ids + if id in ids + ), + ) + + +def prune_aggregate( + node: bigframes.core.nodes.AggregateNode, + used_cols: AbstractSet[bigframes.core.ids.ColumnId], +) -> bigframes.core.nodes.AggregateNode: + pruned_aggs = tuple(agg for agg in node.aggregations if agg[1] in used_cols) + return dataclasses.replace(node, aggregations=pruned_aggs) + + +@functools.singledispatch +def prune_leaf( + node: bigframes.core.nodes.BigFrameNode, + used_cols: AbstractSet[bigframes.core.ids.ColumnId], +): + ... + + +@prune_leaf.register +def prune_readlocal( + node: bigframes.core.nodes.ReadLocalNode, + selection: AbstractSet[bigframes.core.ids.ColumnId], +) -> bigframes.core.nodes.ReadLocalNode: + new_scan_list = filter_scanlist(node.scan_list, selection) + return dataclasses.replace( + node, + scan_list=new_scan_list, + offsets_col=node.offsets_col if (node.offsets_col in selection) else None, + ) + + +@prune_leaf.register +def prune_readtable( + node: bigframes.core.nodes.ReadTableNode, + selection: AbstractSet[bigframes.core.ids.ColumnId], +) -> bigframes.core.nodes.ReadTableNode: + new_scan_list = filter_scanlist(node.scan_list, selection) + return dataclasses.replace(node, scan_list=new_scan_list) + + +def filter_scanlist( + scanlist: bigframes.core.nodes.ScanList, + ids: AbstractSet[bigframes.core.ids.ColumnId], +): + result = bigframes.core.nodes.ScanList( + tuple(item for item in scanlist.items if item.id in ids) + ) + if len(result.items) == 0: + # We need to select something, or stuff breaks + result = bigframes.core.nodes.ScanList(scanlist.items[:1]) + return result + + +T = TypeVar("T") + + +def dedupe(items: Iterable[T]) -> Iterable[T]: + seen = set() + + for item in items: + if item not in seen: + seen.add(item) + yield item diff --git a/bigframes/core/rewrite/slices.py b/bigframes/core/rewrite/slices.py index 102ffcf773..87a7720e2f 100644 --- a/bigframes/core/rewrite/slices.py +++ b/bigframes/core/rewrite/slices.py @@ -120,7 +120,9 @@ def drop_cols( ) -> nodes.SelectionNode: # adding a whole node that redefines the schema is a lot of overhead, should do something more efficient selections = tuple( - (scalar_exprs.DerefOp(id), id) for id in node.ids if id not in drop_cols + nodes.AliasedRef(scalar_exprs.DerefOp(id), id) + for id in node.ids + if id not in drop_cols ) return nodes.SelectionNode(node, selections) From 3ec3158c1c80454ec1b8e4053ec7bf1eab3988b1 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 6 Feb 2025 07:01:23 +0000 Subject: [PATCH 2/6] chore: Add __init__.py to functions test modules --- tests/system/large/functions/__init__.py | 0 tests/system/small/functions/__init__.py | 13 +++++++++++++ tests/unit/functions/__init__.py | 13 +++++++++++++ 3 files changed, 26 insertions(+) create mode 100644 tests/system/large/functions/__init__.py create mode 100644 tests/system/small/functions/__init__.py create mode 100644 tests/unit/functions/__init__.py diff --git a/tests/system/large/functions/__init__.py b/tests/system/large/functions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/system/small/functions/__init__.py b/tests/system/small/functions/__init__.py new file mode 100644 index 0000000000..0a2669d7a2 --- /dev/null +++ b/tests/system/small/functions/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit/functions/__init__.py b/tests/unit/functions/__init__.py new file mode 100644 index 0000000000..0a2669d7a2 --- /dev/null +++ b/tests/unit/functions/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From e01d2ce2a0ea06781a1a29da708f5ed611dcf487 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 6 Feb 2025 07:46:31 +0000 Subject: [PATCH 3/6] fix ref remapping --- bigframes/core/compile/compiled.py | 2 +- bigframes/core/compile/compiler.py | 1 + bigframes/core/nodes.py | 2 +- bigframes/core/rewrite/implicit_align.py | 3 ++- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/bigframes/core/compile/compiled.py b/bigframes/core/compile/compiled.py index 906bdb1f0d..93be998b5b 100644 --- a/bigframes/core/compile/compiled.py +++ b/bigframes/core/compile/compiled.py @@ -184,7 +184,7 @@ def _to_ibis_expr( # Special case for empty tables, since we can't create an empty # projection. if not self._columns: - return bigframes_vendored.ibis.memtable([]) + return self._table.select([bigframes_vendored.ibis.literal(1)]) table = self._table.select(self._columns) if fraction is not None: diff --git a/bigframes/core/compile/compiler.py b/bigframes/core/compile/compiler.py index 48dd7f2c5f..b614a150ba 100644 --- a/bigframes/core/compile/compiler.py +++ b/bigframes/core/compile/compiler.py @@ -62,6 +62,7 @@ def compile_sql( if ordered: node, limit = rewrites.pullup_limit_from_slice(node) node = nodes.bottom_up(node, rewrites.rewrite_slice) + # TODO: Extract out CTEs node, ordering = rewrites.pull_up_order( node, order_root=True, ordered_joins=self.strict ) diff --git a/bigframes/core/nodes.py b/bigframes/core/nodes.py index c64e18f8ed..88e084d79c 100644 --- a/bigframes/core/nodes.py +++ b/bigframes/core/nodes.py @@ -1249,7 +1249,7 @@ def identity(cls, id: bfet_ids.ColumnId) -> AliasedRef: def remap_vars( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] ) -> AliasedRef: - return AliasedRef(self.ref, mappings.get(self.ref.id, self.ref.id)) + return AliasedRef(self.ref, mappings.get(self.id, self.id)) def remap_refs( self, mappings: Mapping[bfet_ids.ColumnId, bfet_ids.ColumnId] diff --git a/bigframes/core/rewrite/implicit_align.py b/bigframes/core/rewrite/implicit_align.py index 7e2d086c82..1989b1a543 100644 --- a/bigframes/core/rewrite/implicit_align.py +++ b/bigframes/core/rewrite/implicit_align.py @@ -190,6 +190,7 @@ def pull_up_selection( child_node, child_selections = pull_up_selection( node.additive_base, stop, rename_vars=rename_vars ) + mapping = {out: ref.id for ref, out in child_selections} new_node: bigframes.core.nodes.BigFrameNode = node.replace_additive_base( child_node ) @@ -204,7 +205,7 @@ def pull_up_selection( var_renames = {} assert isinstance(new_node, bigframes.core.nodes.AdditiveNode) added_selections = tuple( - bigframes.core.nodes.AliasedRef.identity(field.id).remap_vars(var_renames) + bigframes.core.nodes.AliasedRef.identity(field.id).remap_refs(var_renames) for field in node.added_fields ) new_selection = child_selections + added_selections From 64dfe19149b8baa61301ae93d015e476be8a924c Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 6 Feb 2025 18:55:32 +0000 Subject: [PATCH 4/6] revert pruning physical schema passed to ibis --- bigframes/core/compile/compiler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bigframes/core/compile/compiler.py b/bigframes/core/compile/compiler.py index b614a150ba..ff5f1d61c8 100644 --- a/bigframes/core/compile/compiler.py +++ b/bigframes/core/compile/compiler.py @@ -202,23 +202,22 @@ def read_table_as_unordered_ibis( scan_cols: typing.Sequence[str], ) -> ibis_types.Table: full_table_name = f"{source.table.project_id}.{source.table.dataset_id}.{source.table.table_id}" - used_columns = tuple(scan_cols) # Physical schema might include unused columns, unsupported datatypes like JSON physical_schema = ibis_bigquery.BigQuerySchema.to_ibis( - list(i for i in source.table.physical_schema if i.name in used_columns) + list(source.table.physical_schema) ) if source.at_time is not None or source.sql_predicate is not None: import bigframes.session._io.bigquery sql = bigframes.session._io.bigquery.to_query( full_table_name, - columns=used_columns, + columns=scan_cols, sql_predicate=source.sql_predicate, time_travel_timestamp=source.at_time, ) return ibis_bigquery.Backend().sql(schema=physical_schema, query=sql) else: - return ibis_api.table(physical_schema, full_table_name) + return ibis_api.table(physical_schema, full_table_name).select(scan_cols) def compile_read_table_unordered( self, source: nodes.BigqueryDataSource, scan: nodes.ScanList From 54685ccd881ca115f9a95baee856990e9dd7c362 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 6 Feb 2025 20:54:12 +0000 Subject: [PATCH 5/6] port ibis cte ordering fix --- third_party/bigframes_vendored/ibis/backends/sql/rewrites.py | 2 +- third_party/bigframes_vendored/ibis/common/graph.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py b/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py index 652f04757b..a252f116dd 100644 --- a/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py +++ b/third_party/bigframes_vendored/ibis/backends/sql/rewrites.py @@ -359,7 +359,7 @@ def wrap(node, _, **kwargs): return CTE(new) if node in ctes else new result = simplified.replace(wrap) - ctes = reversed([cte.parent for cte in result.find(CTE)]) + ctes = [cte.parent for cte in result.find(CTE, ordered=True)] return result, ctes diff --git a/third_party/bigframes_vendored/ibis/common/graph.py b/third_party/bigframes_vendored/ibis/common/graph.py index 1a3fc6c543..6e7995ec03 100644 --- a/third_party/bigframes_vendored/ibis/common/graph.py +++ b/third_party/bigframes_vendored/ibis/common/graph.py @@ -343,6 +343,7 @@ def find( finder: FinderLike, filter: Optional[FinderLike] = None, context: Optional[dict] = None, + ordered: bool = False, ) -> list[Node]: """Find all nodes matching a given pattern or type in the graph. @@ -360,6 +361,8 @@ def find( the given filter and stop otherwise. context Optional context to use if `finder` or `filter` is a pattern. + ordered + Emit nodes in topological order if `True`. Returns ------- @@ -369,6 +372,8 @@ def find( """ graph = Graph.from_bfs(self, filter=filter, context=context) finder = _coerce_finder(finder, context) + if ordered: + graph, _ = graph.toposort() return [node for node in graph.nodes() if finder(node)] @experimental From 7f9fdbadb24f64fb5938e828f02c0ca33c6261a3 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 6 Feb 2025 22:12:18 +0000 Subject: [PATCH 6/6] remove dead code --- bigframes/core/rewrite/pruning.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/bigframes/core/rewrite/pruning.py b/bigframes/core/rewrite/pruning.py index 3063c95afb..0b8534116d 100644 --- a/bigframes/core/rewrite/pruning.py +++ b/bigframes/core/rewrite/pruning.py @@ -13,7 +13,7 @@ # limitations under the License. import dataclasses import functools -from typing import AbstractSet, Iterable, TypeVar +from typing import AbstractSet import bigframes.core.identifiers import bigframes.core.nodes @@ -193,15 +193,3 @@ def filter_scanlist( # We need to select something, or stuff breaks result = bigframes.core.nodes.ScanList(scanlist.items[:1]) return result - - -T = TypeVar("T") - - -def dedupe(items: Iterable[T]) -> Iterable[T]: - seen = set() - - for item in items: - if item not in seen: - seen.add(item) - yield item