Skip to content

Commit d9aff2f

Browse files
committed
Fix mypy errors in local_weighted_learning.py
1 parent 30ee318 commit d9aff2f

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

machine_learning/local_weighted_learning/local_weighted_learning.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44

55
def weighted_matrix(
6-
point: np.array, training_data_x: np.array, bandwidth: float
7-
) -> np.array:
6+
point: np.ndarray, training_data_x: np.ndarray, bandwidth: float
7+
) -> np.ndarray:
88
"""
99
Calculate the weight for every point in the data set.
1010
point --> the x value at which we want to make predictions
@@ -28,11 +28,11 @@ def weighted_matrix(
2828

2929

3030
def local_weight(
31-
point: np.array,
32-
training_data_x: np.array,
33-
training_data_y: np.array,
31+
point: np.ndarray,
32+
training_data_x: np.ndarray,
33+
training_data_y: np.ndarray,
3434
bandwidth: float,
35-
) -> np.array:
35+
) -> np.ndarray:
3636
"""
3737
Calculate the local weights using the weight_matrix function on training data.
3838
Return the weighted matrix.
@@ -54,8 +54,8 @@ def local_weight(
5454

5555

5656
def local_weight_regression(
57-
training_data_x: np.array, training_data_y: np.array, bandwidth: float
58-
) -> np.array:
57+
training_data_x: np.ndarray, training_data_y: np.ndarray, bandwidth: float
58+
) -> np.ndarray:
5959
"""
6060
Calculate predictions for each data point on axis
6161
>>> local_weight_regression(
@@ -78,7 +78,7 @@ def local_weight_regression(
7878

7979
def load_data(
8080
dataset_name: str, cola_name: str, colb_name: str
81-
) -> tuple[np.array, np.array, np.array, np.array]:
81+
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
8282
"""
8383
Load data from seaborn and split it into x and y points
8484
"""
@@ -99,7 +99,9 @@ def load_data(
9999
return training_data_x, mcol_b, col_a, col_b
100100

101101

102-
def get_preds(training_data_x: np.array, mcol_b: np.array, tau: float) -> np.array:
102+
def get_preds(
103+
training_data_x: np.ndarray, mcol_b: np.ndarray, tau: float
104+
) -> np.ndarray:
103105
"""
104106
Get predictions with minimum error for each training data
105107
>>> get_preds(
@@ -114,10 +116,10 @@ def get_preds(training_data_x: np.array, mcol_b: np.array, tau: float) -> np.arr
114116

115117

116118
def plot_preds(
117-
training_data_x: np.array,
118-
predictions: np.array,
119-
col_x: np.array,
120-
col_y: np.array,
119+
training_data_x: np.ndarray,
120+
predictions: np.ndarray,
121+
col_x: np.ndarray,
122+
col_y: np.ndarray,
121123
cola_name: str,
122124
colb_name: str,
123125
) -> plt.plot:

0 commit comments

Comments
 (0)