|
| 1 | +--- |
| 2 | +description: How to visualize k-Nearest Neighbors (kNN) created using scikit-learn |
| 3 | + in Python with Plotly. |
| 4 | +display_as: ai_ml |
| 5 | +language: python |
| 6 | +layout: base |
| 7 | +name: K-Nearest Neighbors (kNN) Classification |
| 8 | +order: 1 |
| 9 | +page_type: example_index |
| 10 | +permalink: python/knn/ |
| 11 | +redirect_from: python/machine-learning-tutorials/ |
| 12 | +thumbnail: thumbnail/line-and-scatter.jpg |
| 13 | +--- |
| 14 | + |
| 15 | +## Basic Binary Classification with `plotly.express` |
| 16 | + |
| 17 | +```python |
| 18 | +import numpy as np |
| 19 | +import plotly.express as px |
| 20 | +import plotly.graph_objects as go |
| 21 | +from sklearn.datasets import make_moons |
| 22 | +from sklearn.neighbors import KNeighborsClassifier |
| 23 | + |
| 24 | +X, y = make_moons(noise=0.3, random_state=0) |
| 25 | +X_test, _ = make_moons(noise=0.3, random_state=1) |
| 26 | + |
| 27 | +clf = KNeighborsClassifier(15) |
| 28 | +clf.fit(X, y.astype(str)) # Fit on training set |
| 29 | +y_pred = clf.predict(X_test) # Predict on new data |
| 30 | + |
| 31 | +fig = px.scatter(x=X_test[:, 0], y=X_test[:, 1], color=y_pred, labels={'color': 'predicted'}) |
| 32 | +fig.update_traces(marker_size=10) |
| 33 | +fig.show() |
| 34 | +``` |
| 35 | + |
| 36 | +## Visualize Binary Prediction Scores |
| 37 | + |
| 38 | +```python |
| 39 | +import numpy as np |
| 40 | +import plotly.express as px |
| 41 | +import plotly.graph_objects as go |
| 42 | +from sklearn.datasets import make_classification |
| 43 | +from sklearn.neighbors import KNeighborsClassifier |
| 44 | + |
| 45 | +X, y = make_classification(n_features=2, n_redundant=0, random_state=0) |
| 46 | +X_test, _ = make_classification(n_features=2, n_redundant=0, random_state=1) |
| 47 | + |
| 48 | +clf = KNeighborsClassifier(15) |
| 49 | +clf.fit(X, y) # Fit on training set |
| 50 | +y_score = clf.predict_proba(X_test)[:, 1] # Predict on new data |
| 51 | + |
| 52 | +fig = px.scatter(x=X_test[:, 0], y=X_test[:, 1], color=y_score, labels={'color': 'score'}) |
| 53 | +fig.update_traces(marker_size=10) |
| 54 | +fig.show() |
| 55 | +``` |
| 56 | + |
| 57 | +## Probability Estimates with `go.Contour` |
| 58 | + |
| 59 | +```python |
| 60 | +import numpy as np |
| 61 | +import plotly.express as px |
| 62 | +import plotly.graph_objects as go |
| 63 | +from sklearn.datasets import make_moons |
| 64 | +from sklearn.neighbors import KNeighborsClassifier |
| 65 | + |
| 66 | +mesh_size = .02 |
| 67 | +margin = 1 |
| 68 | + |
| 69 | +X, y = make_moons(noise=0.3, random_state=0) |
| 70 | + |
| 71 | +# Create a mesh grid on which we will run our model |
| 72 | +x_min, x_max = X[:, 0].min() - margin, X[:, 0].max() + margin |
| 73 | +y_min, y_max = X[:, 1].min() - margin, X[:, 1].max() + margin |
| 74 | +xrange = np.arange(x_min, x_max, mesh_size) |
| 75 | +yrange = np.arange(y_min, y_max, mesh_size) |
| 76 | +xx, yy = np.meshgrid(xrange, yrange) |
| 77 | + |
| 78 | +# Create classifier, run predictions on grid |
| 79 | +clf = KNeighborsClassifier(15, weights='uniform') |
| 80 | +clf.fit(X, y) |
| 81 | +Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1] |
| 82 | +Z = Z.reshape(xx.shape) |
| 83 | + |
| 84 | +fig = px.scatter(X, x=0, y=1, color=y.astype(str), labels={'0':'', '1':''}) |
| 85 | +fig.update_traces(marker_size=10, marker_line_width=1) |
| 86 | +fig.add_trace( |
| 87 | + go.Contour( |
| 88 | + x=xrange, |
| 89 | + y=yrange, |
| 90 | + z=Z, |
| 91 | + showscale=False, |
| 92 | + colorscale=['Blue', 'Red'], |
| 93 | + opacity=0.4, |
| 94 | + name='Confidence' |
| 95 | + ) |
| 96 | +) |
| 97 | +fig.show() |
| 98 | +``` |
| 99 | + |
| 100 | +## Multi-class prediction confidence with `go.Heatmap` |
| 101 | + |
| 102 | +```python |
| 103 | +import numpy as np |
| 104 | +import plotly.express as px |
| 105 | +import plotly.graph_objects as go |
| 106 | +from sklearn.neighbors import KNeighborsClassifier |
| 107 | + |
| 108 | +mesh_size = .02 |
| 109 | +margin = 1 |
| 110 | + |
| 111 | +# We will use the iris data, which is included in px |
| 112 | +df = px.data.iris() |
| 113 | +X = df[['sepal_length', 'sepal_width']] |
| 114 | +y = df.species_id |
| 115 | + |
| 116 | +# Create a mesh grid on which we will run our model |
| 117 | +l_min, l_max = df.sepal_length.min() - margin, df.sepal_length.max() + margin |
| 118 | +w_min, w_max = df.sepal_width.min() - margin, df.sepal_width.max() + margin |
| 119 | +lrange = np.arange(l_min, l_max, mesh_size) |
| 120 | +wrange = np.arange(w_min, w_max, mesh_size) |
| 121 | +ll, ww = np.meshgrid(lrange, wrange) |
| 122 | + |
| 123 | +# Create classifier, run predictions on grid |
| 124 | +clf = KNeighborsClassifier(15, weights='distance') |
| 125 | +clf.fit(X, y) |
| 126 | +Z = clf.predict(np.c_[ll.ravel(), ww.ravel()]) |
| 127 | +Z = Z.reshape(ll.shape) |
| 128 | +proba = clf.predict_proba(np.c_[ll.ravel(), ww.ravel()]) |
| 129 | +proba = proba.reshape(ll.shape + (3,)) |
| 130 | + |
| 131 | +fig = px.scatter(df, x='sepal_length', y='sepal_width', color='species') |
| 132 | +fig.update_traces(marker_size=10, marker_line_width=1) |
| 133 | +fig.add_trace( |
| 134 | + go.Heatmap( |
| 135 | + x=lrange, |
| 136 | + y=wrange, |
| 137 | + z=Z, |
| 138 | + showscale=False, |
| 139 | + colorscale=[[0.0, 'blue'], [0.5, 'red'], [1.0, 'green']], |
| 140 | + opacity=0.25, |
| 141 | + customdata=proba, |
| 142 | + hovertemplate=( |
| 143 | + 'sepal length: %{x} <br>' |
| 144 | + 'sepal width: %{y} <br>' |
| 145 | + 'p(setosa): %{customdata[0]:.3f}<br>' |
| 146 | + 'p(versicolor): %{customdata[1]:.3f}<br>' |
| 147 | + 'p(virginica): %{customdata[2]:.3f}<extra></extra>' |
| 148 | + ) |
| 149 | + ) |
| 150 | +) |
| 151 | +fig.show() |
| 152 | +``` |
| 153 | + |
| 154 | +## 3D Classification with `px.scatter_3d` |
| 155 | + |
| 156 | +```python |
| 157 | +import numpy as np |
| 158 | +import plotly.express as px |
| 159 | +import plotly.graph_objects as go |
| 160 | +from sklearn.neighbors import KNeighborsClassifier |
| 161 | +from sklearn.model_selection import train_test_split |
| 162 | + |
| 163 | +df = px.data.iris() |
| 164 | +features = ["sepal_width", "sepal_length", "petal_width"] |
| 165 | + |
| 166 | +X = df[features] |
| 167 | +y = df.species |
| 168 | +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0) |
| 169 | + |
| 170 | +# Create classifier, run predictions on grid |
| 171 | +clf = KNeighborsClassifier(15, weights='distance') |
| 172 | +clf.fit(X_train, y_train) |
| 173 | +y_pred = clf.predict(X_test) |
| 174 | +y_score = clf.predict_proba(X_test) |
| 175 | +y_score = np.around(y_score.max(axis=1), 4) |
| 176 | + |
| 177 | +fig = px.scatter_3d( |
| 178 | + X_test, |
| 179 | + x='sepal_length', |
| 180 | + y='sepal_width', |
| 181 | + z='petal_width', |
| 182 | + symbol=y_pred, |
| 183 | + color=y_score, |
| 184 | + labels={'symbol': 'prediction', 'color': 'score'} |
| 185 | +) |
| 186 | +fig.update_layout(legend=dict(x=0, y=0)) |
| 187 | +fig.show() |
| 188 | +``` |
| 189 | + |
| 190 | +## High Dimension Visualization with `px.scatter_matrix` |
| 191 | + |
| 192 | +If you need to visualize classifications that go beyond 3D, you can use the [scatter plot matrix](https://plot.ly/python/splom/). |
| 193 | + |
| 194 | +```python |
| 195 | +import numpy as np |
| 196 | +import plotly.express as px |
| 197 | +import plotly.graph_objects as go |
| 198 | +from sklearn.neighbors import KNeighborsClassifier |
| 199 | +from sklearn.model_selection import train_test_split |
| 200 | + |
| 201 | +df = px.data.iris() |
| 202 | +features = ["sepal_width", "sepal_length", "petal_width", "petal_length"] |
| 203 | + |
| 204 | +X = df[features] |
| 205 | +y = df.species |
| 206 | +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0) |
| 207 | + |
| 208 | +# Create classifier, run predictions on grid |
| 209 | +clf = KNeighborsClassifier(15, weights='distance') |
| 210 | +clf.fit(X_train, y_train) |
| 211 | +y_pred = clf.predict(X_test) |
| 212 | + |
| 213 | +fig = px.scatter_matrix(X_test, dimensions=features, color=y_pred, labels={'color': 'prediction'}) |
| 214 | +fig.show() |
| 215 | +``` |
| 216 | + |
| 217 | +### Reference |
| 218 | + |
| 219 | +Learn more about `px`, `go.Contour`, and `go.Heatmap` here: |
| 220 | +* https://plot.ly/python/plotly-express/ |
| 221 | +* https://plot.ly/python/heatmaps/ |
| 222 | +* https://plot.ly/python/contour-plots/ |
| 223 | +* https://plot.ly/python/3d-scatter-plots/ |
| 224 | +* https://plot.ly/python/splom/ |
| 225 | + |
| 226 | +This tutorial was inspired by amazing examples from the official scikit-learn docs: |
| 227 | +* https://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html |
| 228 | +* https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html |
| 229 | +* https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html |
0 commit comments