Skip to content

Commit 4ae72ab

Browse files
committed
Refactor to remove duplicate var
1 parent f519f82 commit 4ae72ab

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

machine_learning/local_weighted_learning/local_weighted_learning.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def local_weight_regression(
7171

7272
def load_data(
7373
dataset_name: str, x_name: str, y_name: str
74-
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
74+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
7575
"""
7676
Load data from seaborn and split it into x and y points
7777
"""
@@ -81,15 +81,12 @@ def load_data(
8181
x_data = np.array(data[x_name]) # total_bill
8282
y_data = np.array(data[y_name]) # tip
8383

84-
mcol_a = x_data.copy()
85-
mcol_b = y_data.copy()
84+
one = np.ones(np.shape(y_data)[0], dtype=int)
8685

87-
one = np.ones(np.shape(mcol_b)[0], dtype=int)
86+
# pairing elements of one and x_data
87+
x_train = np.column_stack((one, x_data))
8888

89-
# pairing elements of one and mcol_a
90-
x_train = np.column_stack((one, mcol_a))
91-
92-
return x_train, mcol_b, x_data, y_data
89+
return x_train, x_data, y_data
9390

9491

9592
def get_preds(x_train: np.ndarray, y_train: np.ndarray, tau: float) -> np.ndarray:
@@ -108,7 +105,7 @@ def get_preds(x_train: np.ndarray, y_train: np.ndarray, tau: float) -> np.ndarra
108105

109106
def plot_preds(
110107
x_train: np.ndarray,
111-
predictions: np.ndarray,
108+
preds: np.ndarray,
112109
x_data: np.ndarray,
113110
y_data: np.ndarray,
114111
x_name: str,
@@ -121,7 +118,7 @@ def plot_preds(
121118
plt.scatter(x_data, y_data, color="blue")
122119
plt.plot(
123120
x_train_sorted[:, 1],
124-
predictions[x_train[:, 1].argsort(0)],
121+
preds[x_train[:, 1].argsort(0)],
125122
color="yellow",
126123
linewidth=5,
127124
)
@@ -136,6 +133,6 @@ def plot_preds(
136133

137134
doctest.testmod()
138135

139-
training_data_x, mcol_b, total_bill, tip = load_data("tips", "total_bill", "tip")
140-
predictions = get_preds(training_data_x, mcol_b, 0.5)
136+
training_data_x, total_bill, tip = load_data("tips", "total_bill", "tip")
137+
predictions = get_preds(training_data_x, tip, 0.5)
141138
plot_preds(training_data_x, predictions, total_bill, tip, "total_bill", "tip")

0 commit comments

Comments
 (0)