Skip to content

Commit 15b8449

Browse files
authored
feat: add ml.metrics.mean_absolute_error method (#1910)
1 parent 566b5b0 commit 15b8449

File tree

4 files changed

+50
-0
lines changed

4 files changed

+50
-0
lines changed

bigframes/ml/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
auc,
1919
confusion_matrix,
2020
f1_score,
21+
mean_absolute_error,
2122
mean_squared_error,
2223
precision_score,
2324
r2_score,
@@ -36,6 +37,7 @@
3637
"confusion_matrix",
3738
"precision_score",
3839
"f1_score",
40+
"mean_absolute_error",
3941
"mean_squared_error",
4042
"pairwise",
4143
]

bigframes/ml/metrics/_metrics.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,17 @@ def mean_squared_error(
344344
mean_squared_error.__doc__ = inspect.getdoc(
345345
vendored_metrics_regression.mean_squared_error
346346
)
347+
348+
349+
def mean_absolute_error(
350+
y_true: Union[bpd.DataFrame, bpd.Series],
351+
y_pred: Union[bpd.DataFrame, bpd.Series],
352+
) -> float:
353+
y_true_series, y_pred_series = utils.batch_convert_to_series(y_true, y_pred)
354+
355+
return (y_pred_series - y_true_series).abs().sum() / len(y_true_series)
356+
357+
358+
mean_absolute_error.__doc__ = inspect.getdoc(
359+
vendored_metrics_regression.mean_absolute_error
360+
)

tests/system/small/ml/test_metrics.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,3 +818,10 @@ def test_mean_squared_error(session: bigframes.Session):
818818
df = session.read_pandas(pd_df)
819819
mse = metrics.mean_squared_error(df["y_true"], df["y_pred"])
820820
assert mse == 0.375
821+
822+
823+
def test_mean_absolute_error(session: bigframes.Session):
824+
pd_df = pd.DataFrame({"y_true": [3, -0.5, 2, 7], "y_pred": [2.5, 0.0, 2, 8]})
825+
df = session.read_pandas(pd_df)
826+
mse = metrics.mean_absolute_error(df["y_true"], df["y_pred"])
827+
assert mse == 0.5

third_party/bigframes_vendored/sklearn/metrics/_regression.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,30 @@ def mean_squared_error(y_true, y_pred) -> float:
9191
float: Mean squared error.
9292
"""
9393
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
94+
95+
96+
def mean_absolute_error(y_true, y_pred) -> float:
97+
"""Mean absolute error regression loss.
98+
99+
**Examples:**
100+
101+
>>> import bigframes.pandas as bpd
102+
>>> import bigframes.ml.metrics
103+
>>> bpd.options.display.progress_bar = None
104+
105+
>>> y_true = bpd.DataFrame([3, -0.5, 2, 7])
106+
>>> y_pred = bpd.DataFrame([2.5, 0.0, 2, 8])
107+
>>> mae = bigframes.ml.metrics.mean_absolute_error(y_true, y_pred)
108+
>>> mae
109+
np.float64(0.5)
110+
111+
Args:
112+
y_true (Series or DataFrame of shape (n_samples,)):
113+
Ground truth (correct) target values.
114+
y_pred (Series or DataFrame of shape (n_samples,)):
115+
Estimated target values.
116+
117+
Returns:
118+
float: Mean absolute error.
119+
"""
120+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)

0 commit comments

Comments
 (0)