
Commit 9653112

more wip
1 parent 85119ea commit 9653112

File tree: 6 files changed, +47 −40 lines

pandas/core/generic.py

Lines changed: 0 additions & 1 deletion
@@ -9938,7 +9938,6 @@ def _add_numeric_operations(cls):
         )
         @Appender(_num_doc_mad)
         def mad(self, axis=None, skipna=None, level=None):
-            breakpoint()
             if skipna is None:
                 skipna = True
             if axis is None:

pandas/core/groupby/generic.py

Lines changed: 21 additions & 22 deletions
@@ -927,7 +927,6 @@ def aggregate(self, func=None, *args, **kwargs):
                 raise TypeError("Must provide 'func' or tuples of '(column, aggfunc).")

         func = maybe_mangle_lambdas(func)
-
         result, how = self._aggregate(func, *args, **kwargs)
         if how is None:
             return result
@@ -1131,7 +1130,6 @@ def _cython_agg_blocks(
     def _aggregate_frame(self, func, *args, **kwargs) -> DataFrame:
         if self.grouper.nkeys != 1:
             raise AssertionError("Number of keys must be 1")
-
         axis = self.axis
         obj = self._obj_with_exclusions

@@ -1145,7 +1143,6 @@ def _aggregate_frame(self, func, *args, **kwargs) -> DataFrame:
                 data = self.get_group(name, obj=obj)
                 fres = func(data, *args, **kwargs)
                 result[name] = fres
-
         return self._wrap_frame_output(result, obj)

     def _aggregate_item_by_item(self, func, *args, **kwargs) -> DataFrame:
@@ -1185,7 +1182,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
         if len(keys) == 0:
             return DataFrame(index=keys)

-        key_names = self.grouper.names
+        # key_names = self.grouper.names

         # GH12824.
         def first_not_none(values):
@@ -1195,35 +1192,37 @@ def first_not_none(values):
                 return None

         v = first_not_none(values)
-
         if v is None:
             # GH9684. If all values are None, then this will throw an error.
             # We'd prefer it return an empty dataframe.
             return DataFrame()
         elif isinstance(v, DataFrame):
             return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
         elif self.grouper.groupings is not None:
-            if len(self.grouper.groupings) > 1:
-                key_index = self.grouper.result_index
+            # if len(self.grouper.groupings) > 1:
+            key_index = self.grouper.result_index
+            if not self.as_index:
+                key_index = None

-            else:
-                ping = self.grouper.groupings[0]
-                if len(keys) == ping.ngroups:
-                    key_index = ping.group_index
-                    key_index.name = key_names[0]
+            # else:
+            #     breakpoint()
+            #     ping = self.grouper.groupings[0]
+            #     if len(keys) == ping.ngroups:
+            #         key_index = ping.result_index
+            #         key_index.name = key_names[0]

-                    key_lookup = Index(keys)
-                    indexer = key_lookup.get_indexer(key_index)
+            # key_lookup = Index(keys)
+            # indexer = key_lookup.get_indexer(key_index)

-                    # reorder the values
-                    values = [values[i] for i in indexer]
-                else:
+            # # reorder the values
+            # values = [values[i] for i in indexer]
+            # else:

-                    key_index = Index(keys, name=key_names[0])
+            # key_index = Index(keys, name=key_names[0])

-                    # don't use the key indexer
-                    if not self.as_index:
-                        key_index = None
+            # # don't use the key indexer
+            # if not self.as_index:
+            #     key_index = None

             # make Nones an empty object
             v = first_not_none(values)
@@ -1635,7 +1634,7 @@ def _gotitem(self, key, ndim: int, subset=None):
             raise AssertionError("invalid ndim for _gotitem")

     def _wrap_frame_output(self, result, obj) -> DataFrame:
-        result_index = self.grouper.levels[0]
+        result_index = self.grouper.result_index

         if self.axis == 0:
             return DataFrame(result, index=obj.columns, columns=result_index).T
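Note: the _wrap_applied_output and _wrap_frame_output changes above lean on the grouper's result_index instead of rebuilding a key index per grouping. A minimal public-API sketch of the behavior this is meant to preserve (illustrative data, not the patched internals):

import pandas as pd

df = pd.DataFrame({"key": ["a", "a", "b"], "val": [1, 2, 3]})

# single grouping: the wrapped output's index should carry the key's name
single = df.groupby("key").apply(lambda g: g["val"].sum())
print(single.index.name)   # "key"

# multiple groupings: the index should be the MultiIndex of observed groups
multi = df.groupby(["key", "val"]).apply(lambda g: len(g))
print(multi.index.names)   # ['key', 'val']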

pandas/core/groupby/groupby.py

Lines changed: 5 additions & 3 deletions
@@ -641,6 +641,7 @@ def curried(x):
             return self.apply(curried)

         try:
+            # breakpoint()
             return self.apply(curried)
         except TypeError as err:
             if not re.search(
@@ -728,10 +729,11 @@ def f(g):
                 )
         else:
             f = func
-
+        # breakpoint()
         # ignore SettingWithCopy here in case the user mutates
         with option_context("mode.chained_assignment", None):
             try:
+                # breakpoint()
                 result = self._python_apply_general(f)
             except TypeError:
                 # gh-20949
@@ -748,8 +750,9 @@ def f(g):
         return result

     def _python_apply_general(self, f):
+        # breakpoint()
         keys, values, mutated = self.grouper.apply(f, self._selected_obj, self.axis)
-
+        # breakpoint()
         return self._wrap_applied_output(
             keys, values, not_indexed_same=mutated or self.mutated
         )
@@ -943,7 +946,6 @@ def _python_agg_general(self, func, *args, **kwargs):
                 values = ensure_float(values)

             output[key] = self._try_cast(values[mask], result)
-
         return self._wrap_aggregated_output(output)

     def _concat_objects(self, keys, values, not_indexed_same: bool = False):

pandas/core/groupby/ops.py

Lines changed: 19 additions & 7 deletions
@@ -140,15 +140,21 @@ def _get_grouper(self):
         return self.groupings[0].grouper

     def _get_group_keys(self):
-        if len(self.groupings) == 1:
-            return self.levels[0]
-        else:
-            comp_ids, _, ngroups = self.group_info
+        # if len(self.groupings) == 1:
+        #     return self.levels[0]
+        # else:
+        comp_ids, _, ngroups = self.group_info

-            # provide "flattened" iterator for multi-group setting
-            return get_flattened_iterator(comp_ids, ngroups, self.levels, self.codes)
+        # provide "flattened" iterator for multi-group setting
+        flattened_iterator = get_flattened_iterator(
+            comp_ids, ngroups, self.levels, self.codes
+        )
+        if len(self.groupings) == 1:
+            return Index([i[0] for i in flattened_iterator], name=self.levels[0].name)
+        return flattened_iterator

     def apply(self, f, data: FrameOrSeries, axis: int = 0):
+        # breakpoint()
         mutated = self.mutated
         splitter = self._get_splitter(data, axis=axis)
         group_keys = self._get_group_keys()
@@ -261,6 +267,7 @@ def is_monotonic(self) -> bool:

     @cache_readonly
     def group_info(self):
+        # breakpoint()
         comp_ids, obs_group_ids = self._get_compressed_codes()

         ngroups = len(obs_group_ids)
@@ -278,6 +285,7 @@ def codes_info(self) -> np.ndarray:

     def _get_compressed_codes(self) -> Tuple[np.ndarray, np.ndarray]:
         all_codes = self.codes
+        # breakpoint()
         group_index = get_group_index(all_codes, self.shape, sort=True, xnull=True)
         return compress_group_index(group_index, sort=self.sort)

@@ -290,12 +298,14 @@ def ngroups(self) -> int:

     @property
     def reconstructed_codes(self) -> List[np.ndarray]:
+        # breakpoint()
         codes = self.codes
         comp_ids, obs_ids, _ = self.group_info
         return decons_obs_group_ids(comp_ids, obs_ids, self.shape, codes, xnull=True)

     @cache_readonly
     def result_index(self) -> Index:
+        # breakpoint()
         # if not self.compressed and len(self.groupings) == 1:
         #     return self.groupings[0].result_index.rename(self.names[0])
         codes = self.reconstructed_codes
@@ -304,6 +314,7 @@ def result_index(self) -> Index:
             levels=levels, codes=codes, verify_integrity=False, names=self.names
         )
         if not self.compressed and len(self.groupings) == 1:
+            # breakpoint()
             return result.get_level_values(0)
         return result

@@ -599,6 +610,7 @@ def _aggregate(
         is_datetimelike: bool,
         min_count: int = -1,
     ):
+        # breakpoint()
         if agg_func is libgroupby.group_nth:
             # different signature from the others
             # TODO: should we be using min_count instead of hard-coding it?
@@ -831,7 +843,7 @@ def reconstructed_codes(self) -> List[np.ndarray]:

     @cache_readonly
     def result_index(self):
-        breakpoint()
+        # breakpoint()
         if len(self.binlabels) != 0 and isna(self.binlabels[0]):
             return self.binlabels[1:]

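Note: in _get_group_keys, the single-grouping case now reuses the flattened per-group iterator and wraps the first element of each key tuple in an Index that keeps the level's name. A rough standalone sketch of that idea using only public pandas objects (the level and codes below are made up; get_flattened_iterator itself is not reproduced):

import pandas as pd

level = pd.Index(["a", "b", "c"], name="key")  # the grouping's level values
observed_codes = [0, 2]                        # codes of the observed groups, in order

# the flattened iterator yields one tuple of level values per observed group;
# with a single grouping each tuple has exactly one element
flattened = [(level[code],) for code in observed_codes]

# single-grouping case: unwrap the 1-tuples and keep the level's name
group_keys = pd.Index([tup[0] for tup in flattened], name=level.name)
print(group_keys)  # Index(['a', 'c'], dtype='object', name='key')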

pandas/tests/groupby/test_groupby.py

Lines changed: 0 additions & 2 deletions
@@ -654,7 +654,6 @@ def test_groupby_as_index_agg(df):

         gr = df.groupby(ts.values, as_index=True)
         right = getattr(gr, attr)().reset_index(drop=True)
-
         tm.assert_frame_equal(left, right)


@@ -1750,7 +1749,6 @@ def test_empty_dataframe_groupby():
     result = df.groupby("A").sum()
     expected = DataFrame(columns=["B", "C"], dtype=np.float64)
     expected.index.name = "A"
-
     tm.assert_frame_equal(result, expected)



pandas/tests/test_multilevel.py

Lines changed: 2 additions & 5 deletions
@@ -994,21 +994,18 @@ def test_count(self):
         with pytest.raises(KeyError, match=msg):
             frame.count(level="x")

-    @pytest.mark.parametrize("op", ["mad"])
+    @pytest.mark.parametrize("op", AGG_FUNCTIONS)
     @pytest.mark.parametrize("level", [0, 1])
     @pytest.mark.parametrize("skipna", [True, False])
     @pytest.mark.parametrize("sort", [True, False])
     def test_series_group_min_max(self, op, level, skipna, sort):
         # GH 17537
         grouped = self.series.groupby(level=level, sort=sort)
         # skipna=True
-        breakpoint()
         leftside = grouped.agg(lambda x: getattr(x, op)(skipna=skipna))
-        breakpoint()
         rightside = getattr(self.series, op)(level=level, skipna=skipna)
         if sort:
             rightside = rightside.sort_index(level=level)
-        breakpoint()
         tm.assert_series_equal(leftside, rightside)

     @pytest.mark.parametrize("op", AGG_FUNCTIONS)
@@ -1044,7 +1041,7 @@ def aggf(x):

         # for good measure, groupby detail
         level_index = frame._get_axis(axis).levels[level].rename(level_name)
-        breakpoint()
+
         tm.assert_index_equal(leftside._get_axis(axis), level_index)
         tm.assert_index_equal(rightside._get_axis(axis), level_index)

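Note: test_series_group_min_max now runs over all of AGG_FUNCTIONS instead of only "mad". The invariant it checks is that a level-wise groupby aggregation matches the Series' own level= reduction. A hedged sketch with made-up data, using "sum" to stand in for one parametrized op (the level= keyword on reductions is assumed available, as it was in the pandas version this commit targets):

import pandas as pd

idx = pd.MultiIndex.from_tuples(
    [("a", 1), ("a", 2), ("b", 1)], names=["first", "second"]
)
s = pd.Series([1.0, 2.0, 3.0], index=idx)

leftside = s.groupby(level=0).agg(lambda x: x.sum(skipna=True))
rightside = s.sum(level=0, skipna=True)  # the level= path the test exercises

pd.testing.assert_series_equal(leftside, rightside)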
