|
16 | 16 | }, |
17 | 17 | { |
18 | 18 | "cell_type": "code", |
19 | | - "execution_count": 2, |
| 19 | + "execution_count": 1, |
20 | 20 | "metadata": {}, |
21 | 21 | "outputs": [], |
22 | 22 | "source": [ |
23 | 23 | "from sklearn.datasets import load_diabetes\n", |
24 | 24 | "from sklearn.linear_model import Ridge\n", |
25 | 25 | "from sklearn.metrics import mean_squared_error\n", |
26 | 26 | "from sklearn.model_selection import train_test_split\n", |
27 | | - "import joblib" |
| 27 | + "import joblib\n", |
| 28 | + "import pandas as pd" |
28 | 29 | ] |
29 | 30 | }, |
30 | 31 | { |
|
36 | 37 | }, |
37 | 38 | { |
38 | 39 | "cell_type": "code", |
39 | | - "execution_count": 3, |
| 40 | + "execution_count": 6, |
40 | 41 | "metadata": {}, |
41 | 42 | "outputs": [], |
42 | 43 | "source": [ |
43 | | - "X, y = load_diabetes(return_X_y=True)" |
| 44 | + "sample_data = load_diabetes()\n", |
| 45 | + "\n", |
| 46 | + "df = pd.DataFrame(\n", |
| 47 | + " data=sample_data.data,\n", |
| 48 | + " columns=sample_data.feature_names)\n", |
| 49 | + "df['Y'] = sample_data.target" |
44 | 50 | ] |
45 | 51 | }, |
46 | 52 | { |
47 | 53 | "cell_type": "code", |
48 | | - "execution_count": 4, |
| 54 | + "execution_count": 7, |
49 | 55 | "metadata": {}, |
50 | 56 | "outputs": [ |
51 | 57 | { |
|
57 | 63 | } |
58 | 64 | ], |
59 | 65 | "source": [ |
60 | | - "print(X.shape)" |
61 | | - ] |
62 | | - }, |
63 | | - { |
64 | | - "cell_type": "code", |
65 | | - "execution_count": 5, |
66 | | - "metadata": {}, |
67 | | - "outputs": [ |
68 | | - { |
69 | | - "name": "stdout", |
70 | | - "output_type": "stream", |
71 | | - "text": [ |
72 | | - "(442,)\n" |
73 | | - ] |
74 | | - } |
75 | | - ], |
76 | | - "source": [ |
77 | | - "print(y.shape)" |
| 66 | + "print(df.shape)" |
78 | 67 | ] |
79 | 68 | }, |
80 | 69 | { |
81 | 70 | "cell_type": "code", |
82 | | - "execution_count": 8, |
| 71 | + "execution_count": 11, |
83 | 72 | "metadata": {}, |
84 | 73 | "outputs": [ |
85 | 74 | { |
|
103 | 92 | " <thead>\n", |
104 | 93 | " <tr style=\"text-align: right;\">\n", |
105 | 94 | " <th></th>\n", |
106 | | - " <th>0</th>\n", |
107 | | - " <th>1</th>\n", |
108 | | - " <th>2</th>\n", |
109 | | - " <th>3</th>\n", |
110 | | - " <th>4</th>\n", |
111 | | - " <th>5</th>\n", |
112 | | - " <th>6</th>\n", |
113 | | - " <th>7</th>\n", |
114 | | - " <th>8</th>\n", |
115 | | - " <th>9</th>\n", |
| 95 | + " <th>age</th>\n", |
| 96 | + " <th>sex</th>\n", |
| 97 | + " <th>bmi</th>\n", |
| 98 | + " <th>bp</th>\n", |
| 99 | + " <th>s1</th>\n", |
| 100 | + " <th>s2</th>\n", |
| 101 | + " <th>s3</th>\n", |
| 102 | + " <th>s4</th>\n", |
| 103 | + " <th>s5</th>\n", |
| 104 | + " <th>s6</th>\n", |
| 105 | + " <th>Y</th>\n", |
116 | 106 | " </tr>\n", |
117 | 107 | " </thead>\n", |
118 | 108 | " <tbody>\n", |
|
128 | 118 | " <td>4.420000e+02</td>\n", |
129 | 119 | " <td>4.420000e+02</td>\n", |
130 | 120 | " <td>4.420000e+02</td>\n", |
| 121 | + " <td>442.000000</td>\n", |
131 | 122 | " </tr>\n", |
132 | 123 | " <tr>\n", |
133 | 124 | " <td>mean</td>\n", |
134 | | - " <td>-3.639623e-16</td>\n", |
135 | | - " <td>1.309912e-16</td>\n", |
136 | | - " <td>-8.013951e-16</td>\n", |
137 | | - " <td>1.289818e-16</td>\n", |
138 | | - " <td>-9.042540e-17</td>\n", |
139 | | - " <td>1.301121e-16</td>\n", |
140 | | - " <td>-4.563971e-16</td>\n", |
141 | | - " <td>3.863174e-16</td>\n", |
142 | | - " <td>-3.848103e-16</td>\n", |
143 | | - " <td>-3.398488e-16</td>\n", |
| 125 | + " <td>-3.634285e-16</td>\n", |
| 126 | + " <td>1.308343e-16</td>\n", |
| 127 | + " <td>-8.045349e-16</td>\n", |
| 128 | + " <td>1.281655e-16</td>\n", |
| 129 | + " <td>-8.835316e-17</td>\n", |
| 130 | + " <td>1.327024e-16</td>\n", |
| 131 | + " <td>-4.574646e-16</td>\n", |
| 132 | + " <td>3.777301e-16</td>\n", |
| 133 | + " <td>-3.830854e-16</td>\n", |
| 134 | + " <td>-3.412882e-16</td>\n", |
| 135 | + " <td>152.133484</td>\n", |
144 | 136 | " </tr>\n", |
145 | 137 | " <tr>\n", |
146 | 138 | " <td>std</td>\n", |
|
154 | 146 | " <td>4.761905e-02</td>\n", |
155 | 147 | " <td>4.761905e-02</td>\n", |
156 | 148 | " <td>4.761905e-02</td>\n", |
| 149 | + " <td>77.093005</td>\n", |
157 | 150 | " </tr>\n", |
158 | 151 | " <tr>\n", |
159 | 152 | " <td>min</td>\n", |
|
167 | 160 | " <td>-7.639450e-02</td>\n", |
168 | 161 | " <td>-1.260974e-01</td>\n", |
169 | 162 | " <td>-1.377672e-01</td>\n", |
| 163 | + " <td>25.000000</td>\n", |
170 | 164 | " </tr>\n", |
171 | 165 | " <tr>\n", |
172 | 166 | " <td>25%</td>\n", |
|
180 | 174 | " <td>-3.949338e-02</td>\n", |
181 | 175 | " <td>-3.324879e-02</td>\n", |
182 | 176 | " <td>-3.317903e-02</td>\n", |
| 177 | + " <td>87.000000</td>\n", |
183 | 178 | " </tr>\n", |
184 | 179 | " <tr>\n", |
185 | 180 | " <td>50%</td>\n", |
|
193 | 188 | " <td>-2.592262e-03</td>\n", |
194 | 189 | " <td>-1.947634e-03</td>\n", |
195 | 190 | " <td>-1.077698e-03</td>\n", |
| 191 | + " <td>140.500000</td>\n", |
196 | 192 | " </tr>\n", |
197 | 193 | " <tr>\n", |
198 | 194 | " <td>75%</td>\n", |
|
206 | 202 | " <td>3.430886e-02</td>\n", |
207 | 203 | " <td>3.243323e-02</td>\n", |
208 | 204 | " <td>2.791705e-02</td>\n", |
| 205 | + " <td>211.500000</td>\n", |
209 | 206 | " </tr>\n", |
210 | 207 | " <tr>\n", |
211 | 208 | " <td>max</td>\n", |
|
219 | 216 | " <td>1.852344e-01</td>\n", |
220 | 217 | " <td>1.335990e-01</td>\n", |
221 | 218 | " <td>1.356118e-01</td>\n", |
| 219 | + " <td>346.000000</td>\n", |
222 | 220 | " </tr>\n", |
223 | 221 | " </tbody>\n", |
224 | 222 | "</table>\n", |
225 | 223 | "</div>" |
226 | 224 | ], |
227 | 225 | "text/plain": [ |
228 | | - " 0 1 2 3 4 \\\n", |
| 226 | + " age sex bmi bp s1 \\\n", |
229 | 227 | "count 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 \n", |
230 | | - "mean -3.639623e-16 1.309912e-16 -8.013951e-16 1.289818e-16 -9.042540e-17 \n", |
| 228 | + "mean -3.634285e-16 1.308343e-16 -8.045349e-16 1.281655e-16 -8.835316e-17 \n", |
231 | 229 | "std 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 \n", |
232 | 230 | "min -1.072256e-01 -4.464164e-02 -9.027530e-02 -1.123996e-01 -1.267807e-01 \n", |
233 | 231 | "25% -3.729927e-02 -4.464164e-02 -3.422907e-02 -3.665645e-02 -3.424784e-02 \n", |
234 | 232 | "50% 5.383060e-03 -4.464164e-02 -7.283766e-03 -5.670611e-03 -4.320866e-03 \n", |
235 | 233 | "75% 3.807591e-02 5.068012e-02 3.124802e-02 3.564384e-02 2.835801e-02 \n", |
236 | 234 | "max 1.107267e-01 5.068012e-02 1.705552e-01 1.320442e-01 1.539137e-01 \n", |
237 | 235 | "\n", |
238 | | - " 5 6 7 8 9 \n", |
239 | | - "count 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 \n", |
240 | | - "mean 1.301121e-16 -4.563971e-16 3.863174e-16 -3.848103e-16 -3.398488e-16 \n", |
241 | | - "std 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 \n", |
242 | | - "min -1.156131e-01 -1.023071e-01 -7.639450e-02 -1.260974e-01 -1.377672e-01 \n", |
243 | | - "25% -3.035840e-02 -3.511716e-02 -3.949338e-02 -3.324879e-02 -3.317903e-02 \n", |
244 | | - "50% -3.819065e-03 -6.584468e-03 -2.592262e-03 -1.947634e-03 -1.077698e-03 \n", |
245 | | - "75% 2.984439e-02 2.931150e-02 3.430886e-02 3.243323e-02 2.791705e-02 \n", |
246 | | - "max 1.987880e-01 1.811791e-01 1.852344e-01 1.335990e-01 1.356118e-01 " |
| 236 | + " s2 s3 s4 s5 s6 \\\n", |
| 237 | + "count 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 4.420000e+02 \n", |
| 238 | + "mean 1.327024e-16 -4.574646e-16 3.777301e-16 -3.830854e-16 -3.412882e-16 \n", |
| 239 | + "std 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 4.761905e-02 \n", |
| 240 | + "min -1.156131e-01 -1.023071e-01 -7.639450e-02 -1.260974e-01 -1.377672e-01 \n", |
| 241 | + "25% -3.035840e-02 -3.511716e-02 -3.949338e-02 -3.324879e-02 -3.317903e-02 \n", |
| 242 | + "50% -3.819065e-03 -6.584468e-03 -2.592262e-03 -1.947634e-03 -1.077698e-03 \n", |
| 243 | + "75% 2.984439e-02 2.931150e-02 3.430886e-02 3.243323e-02 2.791705e-02 \n", |
| 244 | + "max 1.987880e-01 1.811791e-01 1.852344e-01 1.335990e-01 1.356118e-01 \n", |
| 245 | + "\n", |
| 246 | + " Y \n", |
| 247 | + "count 442.000000 \n", |
| 248 | + "mean 152.133484 \n", |
| 249 | + "std 77.093005 \n", |
| 250 | + "min 25.000000 \n", |
| 251 | + "25% 87.000000 \n", |
| 252 | + "50% 140.500000 \n", |
| 253 | + "75% 211.500000 \n", |
| 254 | + "max 346.000000 " |
247 | 255 | ] |
248 | 256 | }, |
249 | | - "execution_count": 8, |
| 257 | + "execution_count": 11, |
250 | 258 | "metadata": {}, |
251 | 259 | "output_type": "execute_result" |
252 | 260 | } |
253 | 261 | ], |
254 | 262 | "source": [ |
255 | | - "import pandas as pd\n", |
256 | | - "features = pd.DataFrame(X)\n", |
257 | | - "features.describe()" |
| 263 | + "# All data in a single dataframe\n", |
| 264 | + "df.describe()" |
258 | 265 | ] |
259 | 266 | }, |
260 | 267 | { |
|
266 | 273 | }, |
267 | 274 | { |
268 | 275 | "cell_type": "code", |
269 | | - "execution_count": 3, |
| 276 | + "execution_count": 12, |
270 | 277 | "metadata": {}, |
271 | 278 | "outputs": [], |
272 | 279 | "source": [ |
273 | | - "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)\n", |
| 280 | + "X = df.drop('Y', axis=1).values\n", |
| 281 | + "y = df['Y'].values\n", |
| 282 | + "\n", |
| 283 | + "X_train, X_test, y_train, y_test = train_test_split(\n", |
| 284 | + " X, y, test_size=0.2, random_state=0)\n", |
274 | 285 | "data = {\"train\": {\"X\": X_train, \"y\": y_train},\n", |
275 | 286 | " \"test\": {\"X\": X_test, \"y\": y_test}}" |
276 | 287 | ] |
|
284 | 295 | }, |
285 | 296 | { |
286 | 297 | "cell_type": "code", |
287 | | - "execution_count": 4, |
| 298 | + "execution_count": 16, |
288 | 299 | "metadata": {}, |
289 | 300 | "outputs": [ |
290 | 301 | { |
|
294 | 305 | " normalize=False, random_state=None, solver='auto', tol=0.001)" |
295 | 306 | ] |
296 | 307 | }, |
297 | | - "execution_count": 4, |
| 308 | + "execution_count": 16, |
298 | 309 | "metadata": {}, |
299 | 310 | "output_type": "execute_result" |
300 | 311 | } |
301 | 312 | ], |
302 | 313 | "source": [ |
303 | | - "alpha = 0.5\n", |
| 314 | + "# experiment parameters\n", |
| 315 | + "args = {\n", |
| 316 | + " \"alpha\": 0.5\n", |
| 317 | + "}\n", |
304 | 318 | "\n", |
305 | | - "reg = Ridge(alpha=alpha)\n", |
306 | | - "reg.fit(data[\"train\"][\"X\"], data[\"train\"][\"y\"])" |
| 319 | + "reg_model = Ridge(**args)\n", |
| 320 | + "reg_model.fit(data[\"train\"][\"X\"], data[\"train\"][\"y\"])" |
307 | 321 | ] |
308 | 322 | }, |
309 | 323 | { |
|
315 | 329 | }, |
316 | 330 | { |
317 | 331 | "cell_type": "code", |
318 | | - "execution_count": 6, |
| 332 | + "execution_count": 18, |
319 | 333 | "metadata": {}, |
320 | 334 | "outputs": [ |
321 | 335 | { |
322 | 336 | "name": "stdout", |
323 | 337 | "output_type": "stream", |
324 | 338 | "text": [ |
325 | | - "mse: 3298.9096058070622\n" |
| 339 | + "{'mse': 3298.9096058070622}\n" |
326 | 340 | ] |
327 | 341 | } |
328 | 342 | ], |
329 | 343 | "source": [ |
330 | | - "preds = reg.predict(data[\"test\"][\"X\"])\n", |
331 | | - "print(\"mse: \", mean_squared_error(preds, y_test))" |
| 344 | + "preds = reg_model.predict(data[\"test\"][\"X\"])\n", |
| 345 | + "mse = mean_squared_error(preds, y_test)\n", |
| 346 | + "metrics = {\"mse\": mse}\n", |
| 347 | + "print(metrics)" |
332 | 348 | ] |
333 | 349 | }, |
334 | 350 | { |
|
363 | 379 | ], |
364 | 380 | "metadata": { |
365 | 381 | "kernelspec": { |
366 | | - "display_name": "Python (storedna)", |
| 382 | + "display_name": "Python 3", |
367 | 383 | "language": "python", |
368 | | - "name": "storedna" |
| 384 | + "name": "python3" |
369 | 385 | }, |
370 | 386 | "language_info": { |
371 | 387 | "codemirror_mode": { |
|
377 | 393 | "name": "python", |
378 | 394 | "nbconvert_exporter": "python", |
379 | 395 | "pygments_lexer": "ipython3", |
380 | | - "version": "3.6.9" |
| 396 | + "version": "3.7.4" |
381 | 397 | } |
382 | 398 | }, |
383 | 399 | "nbformat": 4, |
|
0 commit comments