Skip to content

Commit 6564d12

Browse files
refactor df binary op alignment
1 parent dbb66e2 commit 6564d12

File tree

4 files changed

+264
-133
lines changed

4 files changed

+264
-133
lines changed

bigframes/core/blocks.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1957,6 +1957,153 @@ def merge(
19571957
expr = joined_expr.promote_offsets(offset_index_id)
19581958
return Block(expr, index_columns=[offset_index_id], column_labels=labels)
19591959

1960+
def _align_both_axes(
1961+
self, other: Block, how: str
1962+
) -> Tuple[Block, pd.Index, Sequence[Tuple[ex.Expression, ex.Expression]]]:
1963+
# Join rows
1964+
aligned_block, (get_column_left, get_column_right) = self.join(other, how=how)
1965+
# join columns schema
1966+
# indexers will be none for exact match
1967+
if self.column_labels.equals(other.column_labels):
1968+
columns, lcol_indexer, rcol_indexer = self.column_labels, None, None
1969+
else:
1970+
columns, lcol_indexer, rcol_indexer = self.column_labels.join(
1971+
other.column_labels, how="outer", return_indexers=True
1972+
)
1973+
lcol_indexer = (
1974+
lcol_indexer if (lcol_indexer is not None) else range(len(columns))
1975+
)
1976+
rcol_indexer = (
1977+
rcol_indexer if (rcol_indexer is not None) else range(len(columns))
1978+
)
1979+
1980+
left_input_lookup = (
1981+
lambda index: ex.free_var(get_column_left[self.value_columns[index]])
1982+
if index != -1
1983+
else ex.const(None)
1984+
)
1985+
righ_input_lookup = (
1986+
lambda index: ex.free_var(get_column_right[other.value_columns[index]])
1987+
if index != -1
1988+
else ex.const(None)
1989+
)
1990+
1991+
left_inputs = [left_input_lookup(i) for i in lcol_indexer]
1992+
right_inputs = [righ_input_lookup(i) for i in rcol_indexer]
1993+
return aligned_block, columns, tuple(zip(left_inputs, right_inputs))
1994+
1995+
def _align_axis_0(
1996+
self, other: Block, how: str
1997+
) -> Tuple[Block, pd.Index, Sequence[Tuple[ex.Expression, ex.Expression]]]:
1998+
assert len(other.value_columns) == 1
1999+
aligned_block, (get_column_left, get_column_right) = self.join(other, how=how)
2000+
2001+
series_column_id = other.value_columns[0]
2002+
inputs = tuple(
2003+
(
2004+
ex.free_var(get_column_left[col]),
2005+
ex.free_var(get_column_right[series_column_id]),
2006+
)
2007+
for col in self.value_columns
2008+
)
2009+
return aligned_block, self.column_labels, inputs
2010+
2011+
def _align_series_block_axis_1(
2012+
self, other: Block, how: str
2013+
) -> Tuple[Block, pd.Index, Sequence[Tuple[ex.Expression, ex.Expression]]]:
2014+
assert len(other.value_columns) == 1
2015+
if other._transpose_cache is None:
2016+
raise ValueError(
2017+
"Wrong align method, this approach requires transpose cache"
2018+
)
2019+
2020+
# Join rows
2021+
aligned_block, (get_column_left, get_column_right) = self.join_row_cross(
2022+
other.transpose()
2023+
)
2024+
# join columns schema
2025+
# indexers will be none for exact match
2026+
if self.column_labels.equals(other.transpose().column_labels):
2027+
columns, lcol_indexer, rcol_indexer = self.column_labels, None, None
2028+
else:
2029+
columns, lcol_indexer, rcol_indexer = self.column_labels.join(
2030+
other.transpose().column_labels, how="outer", return_indexers=True
2031+
)
2032+
lcol_indexer = (
2033+
lcol_indexer if (lcol_indexer is not None) else range(len(columns))
2034+
)
2035+
rcol_indexer = (
2036+
rcol_indexer if (rcol_indexer is not None) else range(len(columns))
2037+
)
2038+
2039+
left_input_lookup = (
2040+
lambda index: ex.free_var(get_column_left[self.value_columns[index]])
2041+
if index != -1
2042+
else ex.const(None)
2043+
)
2044+
righ_input_lookup = (
2045+
lambda index: ex.free_var(
2046+
get_column_right[other.transpose().value_columns[index]]
2047+
)
2048+
if index != -1
2049+
else ex.const(None)
2050+
)
2051+
2052+
left_inputs = [left_input_lookup(i) for i in lcol_indexer]
2053+
right_inputs = [righ_input_lookup(i) for i in rcol_indexer]
2054+
return aligned_block, columns, tuple(zip(left_inputs, right_inputs))
2055+
2056+
def _align_pd_series_axis_1(
2057+
self, other: pd.Series, how: str
2058+
) -> Tuple[Block, pd.Index, Sequence[Tuple[ex.Expression, ex.Expression]]]:
2059+
if self.column_labels.equals(other.index):
2060+
columns, lcol_indexer, rcol_indexer = self.column_labels, None, None
2061+
else:
2062+
if not (self.column_labels.is_unique and other.index.is_unique):
2063+
raise ValueError("Cannot align non-unique indices")
2064+
columns, lcol_indexer, rcol_indexer = self.column_labels.join(
2065+
other.index, how=how, return_indexers=True
2066+
)
2067+
lcol_indexer = (
2068+
lcol_indexer if (lcol_indexer is not None) else range(len(columns))
2069+
)
2070+
rcol_indexer = (
2071+
rcol_indexer if (rcol_indexer is not None) else range(len(columns))
2072+
)
2073+
2074+
left_input_lookup = (
2075+
lambda index: ex.free_var(self.value_columns[index])
2076+
if index != -1
2077+
else ex.const(None)
2078+
)
2079+
righ_input_lookup = (
2080+
lambda index: ex.const(other.iloc[index]) if index != -1 else ex.const(None)
2081+
)
2082+
2083+
left_inputs = [left_input_lookup(i) for i in lcol_indexer]
2084+
right_inputs = [righ_input_lookup(i) for i in rcol_indexer]
2085+
return self, columns, tuple(zip(left_inputs, right_inputs))
2086+
2087+
def _apply_binop(
2088+
self,
2089+
op: ops.BinaryOp,
2090+
inputs: Sequence[Tuple[ex.Expression, ex.Expression]],
2091+
labels: pd.Index,
2092+
reverse: bool = False,
2093+
) -> Block:
2094+
block = self
2095+
binop_result_ids = []
2096+
for left_input, right_input in inputs:
2097+
expr = (
2098+
op.as_expr(right_input, left_input)
2099+
if reverse
2100+
else op.as_expr(left_input, right_input)
2101+
)
2102+
block, result_col_id = block.project_expr(expr)
2103+
binop_result_ids.append(result_col_id)
2104+
2105+
return block.select_columns(binop_result_ids).with_column_labels(labels)
2106+
19602107
def join(
19612108
self,
19622109
other: Block,
@@ -2015,6 +2162,27 @@ def join(
20152162
self, other, how=how, sort=sort, block_identity_join=block_identity_join
20162163
)
20172164

2165+
def join_row_cross(
2166+
self,
2167+
single_row_block: Block,
2168+
) -> Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]:
2169+
"""
2170+
Special join case where other is a single row block.
2171+
This property is not validated, caller responsible for not passing multi-row block.
2172+
Preserves index of this block, ignoring label of other.
2173+
"""
2174+
2175+
if not isinstance(single_row_block, Block):
2176+
# TODO(swast): We need to improve this error message to be more
2177+
# actionable for the user. For example, it's possible they
2178+
# could call set_index and try again to resolve this error.
2179+
raise ValueError(
2180+
f"Tried to join with an unexpected type: {type(single_row_block)}. {constants.FEEDBACK_LINK}"
2181+
)
2182+
2183+
self._throw_if_null_index("join")
2184+
return join_with_single_row(self, single_row_block)
2185+
20182186
def _force_reproject(self) -> Block:
20192187
"""Forces a reprojection of the underlying tables expression. Used to force predicate/order application before subsequent operations."""
20202188
return Block(
@@ -2372,6 +2540,56 @@ def join_indexless(
23722540
)
23732541

23742542

2543+
def join_with_single_row(
2544+
left: Block,
2545+
single_row_block: Block,
2546+
) -> Tuple[Block, Tuple[Mapping[str, str], Mapping[str, str]],]:
2547+
left_expr = left.expr
2548+
# ignore index columns by dropping them
2549+
right_expr = single_row_block.expr.select_columns(single_row_block.value_columns)
2550+
left_mappings = [
2551+
join_defs.JoinColumnMapping(
2552+
source_table=join_defs.JoinSide.LEFT,
2553+
source_id=id,
2554+
destination_id=guid.generate_guid(),
2555+
)
2556+
for id in left_expr.column_ids
2557+
]
2558+
right_mappings = [
2559+
join_defs.JoinColumnMapping(
2560+
source_table=join_defs.JoinSide.RIGHT,
2561+
source_id=id,
2562+
destination_id=guid.generate_guid(),
2563+
)
2564+
for id in right_expr.column_ids # skip index column
2565+
]
2566+
2567+
join_def = join_defs.JoinDefinition(
2568+
conditions=(),
2569+
mappings=(*left_mappings, *right_mappings),
2570+
type="cross",
2571+
)
2572+
combined_expr = left_expr.join(
2573+
right_expr,
2574+
join_def=join_def,
2575+
)
2576+
get_column_left = join_def.get_left_mapping()
2577+
get_column_right = join_def.get_right_mapping()
2578+
# Drop original indices from each side. and used the coalesced combination generated by the join.
2579+
index_cols_post_join = [get_column_left[id] for id in left.index_columns]
2580+
2581+
block = Block(
2582+
combined_expr,
2583+
index_columns=index_cols_post_join,
2584+
column_labels=left.column_labels.append(single_row_block.column_labels),
2585+
index_labels=[left.index.name],
2586+
)
2587+
return (
2588+
block,
2589+
(get_column_left, get_column_right),
2590+
)
2591+
2592+
23752593
def join_mono_indexed(
23762594
left: Block,
23772595
right: Block,

0 commit comments

Comments
 (0)