Skip to content

Commit 6873b30

Browse files
fix: Product operation produces float result for all input types (#501)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent e7a8e46 commit 6873b30

File tree

5 files changed

+8
-12
lines changed

5 files changed

+8
-12
lines changed

bigframes/core/compile/aggregate_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def _(
190190
.else_(magnitude * pow(-1, negative_count_parity))
191191
.end()
192192
)
193-
return float_result.cast(column.type()) # type: ignore
193+
return float_result
194194

195195

196196
@compile_unary_agg.register

bigframes/operations/aggregations.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,7 @@ class ProductOp(UnaryAggregateOp):
139139
name: ClassVar[str] = "product"
140140

141141
def output_type(self, *input_types: dtypes.ExpressionType):
142-
if pd.api.types.is_bool_dtype(input_types[0]):
143-
return dtypes.INT_DTYPE
144-
else:
145-
return input_types[0]
142+
return dtypes.FLOAT_DTYPE
146143

147144

148145
@dataclasses.dataclass(frozen=True)

tests/system/small/test_groupby.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,7 @@ def test_dataframe_groupby_multi_sum(
228228
(lambda x: x.cumsum(numeric_only=True)),
229229
(lambda x: x.cummax(numeric_only=True)),
230230
(lambda x: x.cummin(numeric_only=True)),
231-
# pandas 2.2 uses floating point for cumulative product even for
232-
# integer inputs.
231+
# Pre-pandas 2.2 doesn't always produce float.
233232
(lambda x: x.cumprod().astype("Float64")),
234233
(lambda x: x.shift(periods=2)),
235234
],

tests/system/small/test_series.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1481,7 +1481,7 @@ def test_groupby_prod(scalars_dfs):
14811481
bf_series = scalars_df[col_name].groupby(scalars_df["int64_col"]).prod()
14821482
pd_series = (
14831483
scalars_pandas_df[col_name].groupby(scalars_pandas_df["int64_col"]).prod()
1484-
)
1484+
).astype(pd.Float64Dtype())
14851485
# TODO(swast): Update groupby to use index based on group by key(s).
14861486
bf_result = bf_series.to_pandas()
14871487
assert_series_equal(

third_party/bigframes_vendored/pandas/core/frame.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4416,10 +4416,10 @@ def cumprod(self) -> DataFrame:
44164416
[3 rows x 2 columns]
44174417
44184418
>>> df.cumprod()
4419-
A B
4420-
0 3 1
4421-
1 3 2
4422-
2 6 6
4419+
A B
4420+
0 3.0 1.0
4421+
1 3.0 2.0
4422+
2 6.0 6.0
44234423
<BLANKLINE>
44244424
[3 rows x 2 columns]
44254425

0 commit comments

Comments
 (0)