Skip to content

Commit 0ca1d06

Browse files
committed
Refactor to remove unneeded wrapper function
1 parent 4ae72ab commit 0ca1d06

File tree

1 file changed

+1
-15
lines changed

1 file changed

+1
-15
lines changed

machine_learning/local_weighted_learning/local_weighted_learning.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -89,20 +89,6 @@ def load_data(
8989
return x_train, x_data, y_data
9090

9191

92-
def get_preds(x_train: np.ndarray, y_train: np.ndarray, tau: float) -> np.ndarray:
93-
"""
94-
Get predictions with minimum error for each training data
95-
>>> get_preds(
96-
... np.array([[16.99, 10.34], [21.01, 23.68], [24.59, 25.69]]),
97-
... np.array([[1.01, 1.66, 3.5]]),
98-
... 0.6
99-
... )
100-
array([1.07173261, 1.65970737, 3.50160179])
101-
"""
102-
y_pred = local_weight_regression(x_train, y_train, tau)
103-
return y_pred
104-
105-
10692
def plot_preds(
10793
x_train: np.ndarray,
10894
preds: np.ndarray,
@@ -134,5 +120,5 @@ def plot_preds(
134120
doctest.testmod()
135121

136122
training_data_x, total_bill, tip = load_data("tips", "total_bill", "tip")
137-
predictions = get_preds(training_data_x, tip, 0.5)
123+
predictions = local_weight_regression(training_data_x, tip, 0.5)
138124
plot_preds(training_data_x, predictions, total_bill, tip, "total_bill", "tip")

0 commit comments

Comments
 (0)