Skip to content

Commit a52e7fe

Browse files
author
Kei
committed
Temporarily change observed=True, for groupby.transform
1 parent 888b6bc commit a52e7fe

File tree

2 files changed

+67
-16
lines changed

2 files changed

+67
-16
lines changed

pandas/core/groupby/generic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2044,8 +2044,11 @@ def _gotitem(self, key, ndim: int, subset=None):
20442044
elif ndim == 1:
20452045
if subset is None:
20462046
subset = self.obj[key]
2047+
2048+
orig_obj = self.orig_obj if not self.observed else None
20472049
return SeriesGroupBy(
20482050
subset,
2051+
orig_obj,
20492052
self.keys,
20502053
level=self.level,
20512054
grouper=self._grouper,

pandas/core/groupby/groupby.py

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,7 @@ class GroupBy(BaseGroupBy[NDFrameT]):
10961096
def __init__(
10971097
self,
10981098
obj: NDFrameT,
1099+
orig_obj: NDFrameT | None = None,
10991100
keys: _KeysArgType | None = None,
11001101
level: IndexLabel | None = None,
11011102
grouper: ops.BaseGrouper | None = None,
@@ -1117,6 +1118,7 @@ def __init__(
11171118
self.sort = sort
11181119
self.group_keys = group_keys
11191120
self.dropna = dropna
1121+
self.orig_obj = obj if orig_obj is None else orig_obj
11201122

11211123
if grouper is None:
11221124
grouper, exclusions, obj = get_grouper(
@@ -1879,24 +1881,70 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
18791881

18801882
else:
18811883
# i.e. func in base.reduction_kernels
1884+
if self.observed:
1885+
return self._reduction_kernel_transform(
1886+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
1887+
)
18821888

1883-
# GH#30918 Use _transform_fast only when we know func is an aggregation
1884-
# If func is a reduction, we need to broadcast the
1885-
# result to the whole group. Compute func result
1886-
# and deal with possible broadcasting below.
1887-
with com.temp_setattr(self, "as_index", True):
1888-
# GH#49834 - result needs groups in the index for
1889-
# _wrap_transform_fast_result
1890-
if func in ["idxmin", "idxmax"]:
1891-
func = cast(Literal["idxmin", "idxmax"], func)
1892-
result = self._idxmax_idxmin(func, True, *args, **kwargs)
1893-
else:
1894-
if engine is not None:
1895-
kwargs["engine"] = engine
1896-
kwargs["engine_kwargs"] = engine_kwargs
1897-
result = getattr(self, func)(*args, **kwargs)
1889+
grouper, exclusions, obj = get_grouper(
1890+
self.orig_obj,
1891+
self.keys,
1892+
level=self.level,
1893+
sort=self.sort,
1894+
observed=True,
1895+
dropna=self.dropna,
1896+
)
1897+
exclusions = frozenset(exclusions) if exclusions else frozenset()
1898+
obj_has_not_changed = self.orig_obj.equals(self.obj)
1899+
1900+
with (
1901+
com.temp_setattr(self, "observed", True),
1902+
com.temp_setattr(self, "_grouper", grouper),
1903+
com.temp_setattr(self, "exclusions", exclusions),
1904+
com.temp_setattr(self, "obj", obj, condition=obj_has_not_changed),
1905+
):
1906+
return self._reduction_kernel_transform(
1907+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
1908+
)
1909+
1910+
# with com.temp_setattr(self, "as_index", True):
1911+
# # GH#49834 - result needs groups in the index for
1912+
# # _wrap_transform_fast_result
1913+
# if func in ["idxmin", "idxmax"]:
1914+
# func = cast(Literal["idxmin", "idxmax"], func)
1915+
# result = self._idxmax_idxmin(func, True, *args, **kwargs)
1916+
# else:
1917+
# if engine is not None:
1918+
# kwargs["engine"] = engine
1919+
# kwargs["engine_kwargs"] = engine_kwargs
1920+
# result = getattr(self, func)(*args, **kwargs)
1921+
1922+
# print("result with observed = False\n", result.to_string())
1923+
# r = self._wrap_transform_fast_result(result)
1924+
# print("reindexed result", r.to_string())
1925+
# return r
1926+
1927+
@final
1928+
def _reduction_kernel_transform(
1929+
self, func, *args, engine=None, engine_kwargs=None, **kwargs
1930+
):
1931+
# GH#30918 Use _transform_fast only when we know func is an aggregation
1932+
# If func is a reduction, we need to broadcast the
1933+
# result to the whole group. Compute func result
1934+
# and deal with possible broadcasting below.
1935+
with com.temp_setattr(self, "as_index", True):
1936+
# GH#49834 - result needs groups in the index for
1937+
# _wrap_transform_fast_result
1938+
if func in ["idxmin", "idxmax"]:
1939+
func = cast(Literal["idxmin", "idxmax"], func)
1940+
result = self._idxmax_idxmin(func, True, *args, **kwargs)
1941+
else:
1942+
if engine is not None:
1943+
kwargs["engine"] = engine
1944+
kwargs["engine_kwargs"] = engine_kwargs
1945+
result = getattr(self, func)(*args, **kwargs)
18981946

1899-
return self._wrap_transform_fast_result(result)
1947+
return self._wrap_transform_fast_result(result)
19001948

19011949
@final
19021950
def _wrap_transform_fast_result(self, result: NDFrameT) -> NDFrameT:

0 commit comments

Comments
 (0)