绘制线性可分支持向量机决策边界图 代码解析
### 绘制线性可分支持向量机决策边界图
def plot_classifer(model, X, y):
# 超参数边界
x_min = -7
x_max = 12
y_min = -12
y_max = -1
step = 0.05
# meshgrid
xx, yy = np.meshgrid(np.arange(x_min, x_max, step),
np.arange(y_min, y_max, step))
# 模型预测
z = model.predict(np.c_[xx.ravel(), yy.ravel()])
# 定义color map
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA'])
cmap_bold = ListedColormap(['#FF0000', '#003300'])
z = z.reshape(xx.shape)
plt.figure(figsize=(8, 5), dpi=96)
plt.pcolormesh(xx, yy, z, cmap=cmap_light)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold)
plt.show()
该代码用于绘制线性可分支持向量机(SVM)的决策边界图。通过在二维坐标系中可视化支持向量机的分类结果,我们可以清楚地看到决策边界是如何将不同类别的样本分开的。接下来,将详细解释代码的各个部分以及它是如何工作的。
代码详细解释
(1) 定义决策边界的绘制范围
x_min = -7
x_max = 12
y_min = -12
y_max = -1
step = 0.05
这些变量定义了绘制决策边界的坐标范围:
x_min
和x_max
:x 轴的最小和最大值,表示横向坐标的范围。y_min
和y_max
:y 轴的最小和最大值,表示纵向坐标的范围。step
:绘制网格的步长,决定了网格的密度,越小的步长会导致决策边界更加细致。
(2) 创建网格点
xx, yy = np.meshgrid(np.arange(x_min, x_max, step),
np.arange(y_min, y_max, step))
np.meshgrid()
:该函数生成了一个二维的网格,其中每个点代表输入空间的一个坐标点。通过定义网格,我们可以对整个输入空间的每个点进行分类。xx
和yy
:分别是网格的 x 和 y 坐标。
例如,如果步长为 0.05 且范围为 -7 到 12,网格的 x 坐标将是从 -7 到 12 间隔 0.05 的所有点,y 坐标将是从 -12 到 -1 间隔 0.05 的所有点。
(3) 对网格中的点进行预测
z = model.predict(np.c_[xx.ravel(), yy.ravel()])
np.c_
:将xx
和yy
坐标点展平(通过ravel()
函数),并将这些点组合为一对对的坐标点输入到模型中。model.predict()
:使用训练好的 SVM 模型对网格中的每一个点进行预测,判断它属于哪个类别。
(4) 定义颜色映射
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA'])
cmap_bold = ListedColormap(['#FF0000', '#003300'])
cmap_light
:定义浅色的颜色映射,用于背景显示不同类别区域。cmap_bold
:定义深色的颜色映射,用于显示训练数据点的颜色。
(5) 绘制决策边界和样本点
z = z.reshape(xx.shape)
plt.figure(figsize=(8, 5), dpi=96)
plt.pcolormesh(xx, yy, z, cmap=cmap_light)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold)
plt.show()
z.reshape(xx.shape)
:将预测结果z
重新调整为与xx
的形状相同,以便后续的可视化。plt.pcolormesh()
:使用pcolormesh()
函数绘制背景颜色,表示每个区域的类别。plt.scatter()
:用scatter()
函数绘制数据点,用深色显示训练数据。plt.show()
:显示绘制的图像。
代码作用总结
这段代码通过以下步骤绘制了线性可分支持向量机的决策边界:
- 定义网格:通过
np.meshgrid()
函数创建输入空间的网格点。 - 模型预测:使用
model.predict()
函数对网格中的每个点进行预测,确定该点的类别。 - 绘制决策边界:使用
plt.pcolormesh()
函数绘制每个区域的背景颜色,代表不同类别的区域。 - 绘制样本点:使用
plt.scatter()
函数绘制训练样本,显示真实的分类结果。
使用示例
假设你已经训练了一个 SVM 模型,并且有一些二维数据,那么你可以这样调用函数 plot_classifer()
:
# 假设我们有训练好的 SVM 模型和数据
svm_model = Hard_Margin_SVM()
svm_model.fit(X, y)
# 绘制决策边界
plot_classifer(svm_model, X, y)
总结
通过这段代码,你可以直观地看到 SVM 如何将样本分为两个类别,并展示它的分类边界。