Skip to content

Commit ebd91ab

Browse files
committed
Use infer_dtype to detect strings; add new regex for tests
1 parent 8f8dfd8 commit ebd91ab

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

pandas/core/nanops.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
iNaT,
2121
lib,
2222
)
23+
from pandas._libs.lib import infer_dtype
2324
from pandas._typing import (
2425
ArrayLike,
2526
Dtype,
@@ -43,7 +44,6 @@
4344
is_numeric_dtype,
4445
is_object_dtype,
4546
is_scalar,
46-
is_string_dtype,
4747
is_timedelta64_dtype,
4848
needs_i8_conversion,
4949
pandas_dtype,
@@ -693,10 +693,9 @@ def nanmean(
693693

694694
count = _get_counts(values.shape, mask, axis, dtype=dtype_count)
695695
the_sum = values.sum(axis, dtype=dtype_sum)
696-
if isinstance(the_sum, str) or is_string_dtype(the_sum):
696+
if infer_dtype(the_sum) in ("string", "byte", "mixed-integer", "mixed"):
697697
raise TypeError("cannot find the mean of type 'str'")
698-
else:
699-
_ensure_numeric(the_sum)
698+
the_sum = _ensure_numeric(the_sum)
700699

701700
if axis is not None and getattr(the_sum, "ndim", False):
702701
count = cast(np.ndarray, count)

pandas/tests/apply/test_invalid_arg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,10 @@ def test_agg_cython_table_raises_frame(df, func, expected, axis):
257257
)
258258
def test_agg_cython_table_raises_series(series, func, expected):
259259
# GH21224
260-
msg = r"[Cc]ould not convert|can't multiply sequence by non-int of type"
260+
msg = (
261+
r"[Cc]ould not convert|can't multiply sequence by non-int of type"
262+
r"|cannot find the mean of type 'str'"
263+
)
261264
with pytest.raises(expected, match=msg):
262265
# e.g. Series('a b'.split()).cumprod() will raise
263266
series.agg(func)

pandas/tests/groupby/aggregate/test_cython.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def test_cython_agg_nothing_to_agg():
9292
with pytest.raises(NotImplementedError, match="does not implement"):
9393
frame.groupby("a")["b"].mean(numeric_only=True)
9494

95-
with pytest.raises(TypeError, match="Could not convert (foo|bar)*"):
95+
with pytest.raises(TypeError, match="cannot find the mean of*"):
9696
frame.groupby("a")["b"].mean()
9797

9898
frame = DataFrame({"a": np.random.randint(0, 5, 50), "b": ["foo", "bar"] * 25})

0 commit comments

Comments
 (0)