Skip to content

Commit f519f82

Browse files
committed
Rename vars for clarity
1 parent d9aff2f commit f519f82

File tree

1 file changed

+39
-49
lines changed

1 file changed

+39
-49
lines changed

machine_learning/local_weighted_learning/local_weighted_learning.py

Lines changed: 39 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
import numpy as np
33

44

5-
def weighted_matrix(
6-
point: np.ndarray, training_data_x: np.ndarray, bandwidth: float
7-
) -> np.ndarray:
5+
def weight_matrix(point: np.ndarray, x_train: np.ndarray, tau: float) -> np.ndarray:
86
"""
97
Calculate the weight for every point in the data set.
108
point --> the x value at which we want to make predictions
11-
>>> weighted_matrix(
9+
>>> weight_matrix(
1210
... np.array([1., 1.]),
1311
... np.array([[16.99, 10.34], [21.01,23.68], [24.59,25.69]]),
1412
... 0.6
@@ -17,21 +15,18 @@ def weighted_matrix(
1715
[0.00000000e+000, 0.00000000e+000, 0.00000000e+000],
1816
[0.00000000e+000, 0.00000000e+000, 0.00000000e+000]])
1917
"""
20-
m, _ = np.shape(training_data_x) # m is the number of training samples
18+
m, _ = np.shape(x_train) # m is the number of training samples
2119
weights = np.eye(m) # Initializing weights as identity matrix
2220

2321
# calculating weights for all training examples [x(i)'s]
2422
for j in range(m):
25-
diff = point - training_data_x[j]
26-
weights[j, j] = np.exp(diff @ diff.T / (-2.0 * bandwidth**2))
23+
diff = point - x_train[j]
24+
weights[j, j] = np.exp(diff @ diff.T / (-2.0 * tau**2))
2725
return weights
2826

2927

3028
def local_weight(
31-
point: np.ndarray,
32-
training_data_x: np.ndarray,
33-
training_data_y: np.ndarray,
34-
bandwidth: float,
29+
point: np.ndarray, x_train: np.ndarray, y_train: np.ndarray, tau: float
3530
) -> np.ndarray:
3631
"""
3732
Calculate the local weights using the weight_matrix function on training data.
@@ -45,16 +40,16 @@ def local_weight(
4540
array([[0.00873174],
4641
[0.08272556]])
4742
"""
48-
weight = weighted_matrix(point, training_data_x, bandwidth)
49-
w = np.linalg.inv(training_data_x.T @ (weight @ training_data_x)) @ (
50-
training_data_x.T @ weight @ training_data_y.T
43+
weight_mat = weight_matrix(point, x_train, tau)
44+
weight = np.linalg.inv(x_train.T @ weight_mat @ x_train) @ (
45+
x_train.T @ weight_mat @ y_train.T
5146
)
5247

53-
return w
48+
return weight
5449

5550

5651
def local_weight_regression(
57-
training_data_x: np.ndarray, training_data_y: np.ndarray, bandwidth: float
52+
x_train: np.ndarray, y_train: np.ndarray, tau: float
5853
) -> np.ndarray:
5954
"""
6055
Calculate predictions for each data point on axis
@@ -65,43 +60,39 @@ def local_weight_regression(
6560
... )
6661
array([1.07173261, 1.65970737, 3.50160179])
6762
"""
68-
m, _ = np.shape(training_data_x)
69-
ypred = np.zeros(m)
63+
m, _ = np.shape(x_train)
64+
y_pred = np.zeros(m)
7065

71-
for i, item in enumerate(training_data_x):
72-
ypred[i] = item @ local_weight(
73-
item, training_data_x, training_data_y, bandwidth
74-
)
66+
for i, item in enumerate(x_train):
67+
y_pred[i] = item @ local_weight(item, x_train, y_train, tau)
7568

76-
return ypred
69+
return y_pred
7770

7871

7972
def load_data(
80-
dataset_name: str, cola_name: str, colb_name: str
73+
dataset_name: str, x_name: str, y_name: str
8174
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
8275
"""
8376
Load data from seaborn and split it into x and y points
8477
"""
8578
import seaborn as sns
8679

8780
data = sns.load_dataset(dataset_name)
88-
col_a = np.array(data[cola_name]) # total_bill
89-
col_b = np.array(data[colb_name]) # tip
81+
x_data = np.array(data[x_name]) # total_bill
82+
y_data = np.array(data[y_name]) # tip
9083

91-
mcol_a = col_a.copy()
92-
mcol_b = col_b.copy()
84+
mcol_a = x_data.copy()
85+
mcol_b = y_data.copy()
9386

9487
one = np.ones(np.shape(mcol_b)[0], dtype=int)
9588

9689
# pairing elements of one and mcol_a
97-
training_data_x = np.column_stack((one, mcol_a))
90+
x_train = np.column_stack((one, mcol_a))
9891

99-
return training_data_x, mcol_b, col_a, col_b
92+
return x_train, mcol_b, x_data, y_data
10093

10194

102-
def get_preds(
103-
training_data_x: np.ndarray, mcol_b: np.ndarray, tau: float
104-
) -> np.ndarray:
95+
def get_preds(x_train: np.ndarray, y_train: np.ndarray, tau: float) -> np.ndarray:
10596
"""
10697
Get predictions with minimum error for each training data
10798
>>> get_preds(
@@ -111,33 +102,32 @@ def get_preds(
111102
... )
112103
array([1.07173261, 1.65970737, 3.50160179])
113104
"""
114-
ypred = local_weight_regression(training_data_x, mcol_b, tau)
115-
return ypred
105+
y_pred = local_weight_regression(x_train, y_train, tau)
106+
return y_pred
116107

117108

118109
def plot_preds(
119-
training_data_x: np.ndarray,
110+
x_train: np.ndarray,
120111
predictions: np.ndarray,
121-
col_x: np.ndarray,
122-
col_y: np.ndarray,
123-
cola_name: str,
124-
colb_name: str,
112+
x_data: np.ndarray,
113+
y_data: np.ndarray,
114+
x_name: str,
115+
y_name: str,
125116
) -> plt.plot:
126117
"""
127118
Plot predictions and display the graph
128119
"""
129-
xsort = training_data_x.copy()
130-
xsort.sort(axis=0)
131-
plt.scatter(col_x, col_y, color="blue")
120+
x_train_sorted = np.sort(x_train, axis=0)
121+
plt.scatter(x_data, y_data, color="blue")
132122
plt.plot(
133-
xsort[:, 1],
134-
predictions[training_data_x[:, 1].argsort(0)],
123+
x_train_sorted[:, 1],
124+
predictions[x_train[:, 1].argsort(0)],
135125
color="yellow",
136126
linewidth=5,
137127
)
138128
plt.title("Local Weighted Regression")
139-
plt.xlabel(cola_name)
140-
plt.ylabel(colb_name)
129+
plt.xlabel(x_name)
130+
plt.ylabel(y_name)
141131
plt.show()
142132

143133

@@ -146,6 +136,6 @@ def plot_preds(
146136

147137
doctest.testmod()
148138

149-
training_data_x, mcol_b, col_a, col_b = load_data("tips", "total_bill", "tip")
139+
training_data_x, mcol_b, total_bill, tip = load_data("tips", "total_bill", "tip")
150140
predictions = get_preds(training_data_x, mcol_b, 0.5)
151-
plot_preds(training_data_x, predictions, col_a, col_b, "total_bill", "tip")
141+
plot_preds(training_data_x, predictions, total_bill, tip, "total_bill", "tip")

0 commit comments

Comments
 (0)