This repository has been archived by the owner on Oct 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualization.py
56 lines (46 loc) · 1.51 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
Copyright Jeremy Nation <[email protected]>.
Licensed under the MIT license.
Almost entirely copied from code created by Sebastian Raschka, also licensed under the MIT license.
"""
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import numpy as np
def plot_decision_regions(X, y, classifier, resolution=0.02, test_index=None):
    """Plot a 2-D classifier's decision regions with the samples overlaid.

    Draws onto the current pyplot figure/axes; nothing is returned. Only the
    first two columns of ``X`` are used.

    Parameters
    ----------
    X : numpy.ndarray of shape (n_samples, >=2)
        Feature matrix; columns 0 and 1 become the plot's x and y axes.
    y : numpy.ndarray of shape (n_samples,)
        Class label for each row of ``X`` (used as a boolean row mask).
        At most five distinct labels are supported, one marker/colour each.
    classifier : object
        Fitted model whose ``predict`` accepts an ``(m, 2)`` array of grid
        points. Its predictions are fed to ``contourf``, so they are
        presumably numeric labels -- confirm for non-numeric classifiers.
    resolution : float, optional
        Grid step used to sample the decision surface. Must be positive.
    test_index : index, slice or mask, optional
        Rows of ``X`` to highlight as the test set with hollow circles.

    Raises
    ------
    ValueError
        If ``resolution`` is not positive.
    """
    if resolution <= 0:
        raise ValueError('resolution must be positive')

    markers = ('s', 'x', 'o', '^', 'v')
    colors = ('red', 'blue', 'lightgreen', 'gray', 'cyan')
    classes = np.unique(y)
    cmap = ListedColormap(colors[:len(classes)])

    # Pad the grid by one unit on every side so edge samples are not drawn
    # right on the plot border.
    x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx1, xx2 = np.meshgrid(
        np.arange(x1_min, x1_max, resolution),
        np.arange(x2_min, x2_max, resolution),
    )
    # Classify every grid point, then fold the predictions back into grid shape.
    Z = classifier.predict(np.array([xx1.ravel(), xx2.ravel()]).T)
    Z = Z.reshape(xx1.shape)
    plt.contourf(xx1, xx2, Z, alpha=0.4, cmap=cmap)
    plt.xlim(xx1.min(), xx1.max())
    plt.ylim(xx2.min(), xx2.max())

    for index, class_ in enumerate(classes):
        plt.scatter(
            x=X[y == class_, 0],
            y=X[y == class_, 1],
            alpha=0.8,
            # ``color`` rather than ``c``: a single RGBA tuple passed as ``c``
            # is value-mapped through a colormap when the class has exactly
            # four samples, producing the wrong colour.
            color=cmap(index),
            marker=markers[index],
            label=class_,
        )

    if test_index is not None:
        X_test = X[test_index, :]
        # Hollow black circles over the test samples. The old ``c=''`` idiom
        # is rejected by current Matplotlib; explicit face/edge colours give
        # the intended look.
        plt.scatter(
            X_test[:, 0],
            X_test[:, 1],
            s=55,
            facecolors='none',
            edgecolors='black',
            marker='o',
            alpha=1.0,
            linewidths=1,
            label='test set',
        )