@@ -71,7 +71,7 @@ def local_weight_regression(
71
71
72
72
def load_data (
73
73
dataset_name : str , x_name : str , y_name : str
74
- ) -> tuple [np .ndarray , np .ndarray , np .ndarray , np . ndarray ]:
74
+ ) -> tuple [np .ndarray , np .ndarray , np .ndarray ]:
75
75
"""
76
76
Load data from seaborn and split it into x and y points
77
77
"""
@@ -81,15 +81,12 @@ def load_data(
81
81
x_data = np .array (data [x_name ]) # total_bill
82
82
y_data = np .array (data [y_name ]) # tip
83
83
84
- mcol_a = x_data .copy ()
85
- mcol_b = y_data .copy ()
84
+ one = np .ones (np .shape (y_data )[0 ], dtype = int )
86
85
87
- one = np .ones (np .shape (mcol_b )[0 ], dtype = int )
86
+ # pairing elements of one and x_data
87
+ x_train = np .column_stack ((one , x_data ))
88
88
89
- # pairing elements of one and mcol_a
90
- x_train = np .column_stack ((one , mcol_a ))
91
-
92
- return x_train , mcol_b , x_data , y_data
89
+ return x_train , x_data , y_data
93
90
94
91
95
92
def get_preds (x_train : np .ndarray , y_train : np .ndarray , tau : float ) -> np .ndarray :
@@ -108,7 +105,7 @@ def get_preds(x_train: np.ndarray, y_train: np.ndarray, tau: float) -> np.ndarra
108
105
109
106
def plot_preds (
110
107
x_train : np .ndarray ,
111
- predictions : np .ndarray ,
108
+ preds : np .ndarray ,
112
109
x_data : np .ndarray ,
113
110
y_data : np .ndarray ,
114
111
x_name : str ,
@@ -121,7 +118,7 @@ def plot_preds(
121
118
plt .scatter (x_data , y_data , color = "blue" )
122
119
plt .plot (
123
120
x_train_sorted [:, 1 ],
124
- predictions [x_train [:, 1 ].argsort (0 )],
121
+ preds [x_train [:, 1 ].argsort (0 )],
125
122
color = "yellow" ,
126
123
linewidth = 5 ,
127
124
)
@@ -136,6 +133,6 @@ def plot_preds(
136
133
137
134
doctest .testmod ()
138
135
139
- training_data_x , mcol_b , total_bill , tip = load_data ("tips" , "total_bill" , "tip" )
140
- predictions = get_preds (training_data_x , mcol_b , 0.5 )
136
+ training_data_x , total_bill , tip = load_data ("tips" , "total_bill" , "tip" )
137
+ predictions = get_preds (training_data_x , tip , 0.5 )
141
138
plot_preds (training_data_x , predictions , total_bill , tip , "total_bill" , "tip" )
0 commit comments