当前位置: 首页 > article >正文

绘制线性可分支持向量机决策边界图 代码解析

### 绘制线性可分支持向量机决策边界图
def plot_classifer(model, X, y):
    # 超参数边界
    x_min = -7
    x_max = 12
    y_min = -12
    y_max = -1
    step = 0.05
    # meshgrid
    xx, yy = np.meshgrid(np.arange(x_min, x_max, step),
                         np.arange(y_min, y_max, step))
    # 模型预测
    z = model.predict(np.c_[xx.ravel(), yy.ravel()])

    # 定义color map
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA'])
    cmap_bold = ListedColormap(['#FF0000', '#003300'])
    z = z.reshape(xx.shape)

    plt.figure(figsize=(8, 5), dpi=96)
    plt.pcolormesh(xx, yy, z, cmap=cmap_light)
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold)
    plt.show()

该代码用于绘制线性可分支持向量机(SVM)的决策边界图。通过在二维坐标系中可视化支持向量机的分类结果,我们可以清楚地看到决策边界是如何将不同类别的样本分开的。接下来,将详细解释代码的各个部分以及它是如何工作的。

代码详细解释

(1) 定义决策边界的绘制范围
x_min = -7
x_max = 12
y_min = -12
y_max = -1
step = 0.05

这些变量定义了绘制决策边界的坐标范围:

  • x_minx_max:x 轴的最小和最大值,表示横向坐标的范围。
  • y_miny_max:y 轴的最小和最大值,表示纵向坐标的范围。
  • step:绘制网格的步长,决定了网格的密度,越小的步长会导致决策边界更加细致。
(2) 创建网格点
xx, yy = np.meshgrid(np.arange(x_min, x_max, step),
                     np.arange(y_min, y_max, step))
  • np.meshgrid():该函数生成了一个二维的网格,其中每个点代表输入空间的一个坐标点。通过定义网格,我们可以对整个输入空间的每个点进行分类。
  • xxyy:分别是网格的 x 和 y 坐标。

例如,如果步长为 0.05 且范围为 -7 到 12,网格的 x 坐标将是从 -7 到 12 间隔 0.05 的所有点,y 坐标将是从 -12 到 -1 间隔 0.05 的所有点。

(3) 对网格中的点进行预测
z = model.predict(np.c_[xx.ravel(), yy.ravel()])
  • np.c_:将 xxyy 坐标点展平(通过 ravel() 函数),并将这些点组合为一对对的坐标点输入到模型中。
  • model.predict():使用训练好的 SVM 模型对网格中的每一个点进行预测,判断它属于哪个类别。
(4) 定义颜色映射
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA'])
cmap_bold = ListedColormap(['#FF0000', '#003300'])
  • cmap_light:定义浅色的颜色映射,用于背景显示不同类别区域。
  • cmap_bold:定义深色的颜色映射,用于显示训练数据点的颜色。
(5) 绘制决策边界和样本点
z = z.reshape(xx.shape)
plt.figure(figsize=(8, 5), dpi=96)
plt.pcolormesh(xx, yy, z, cmap=cmap_light)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold)
plt.show()
  • z.reshape(xx.shape):将预测结果 z 重新调整为与 xx 的形状相同,以便后续的可视化。
  • plt.pcolormesh():使用 pcolormesh() 函数绘制背景颜色,表示每个区域的类别。
  • plt.scatter():用 scatter() 函数绘制数据点,用深色显示训练数据。
  • plt.show():显示绘制的图像。

代码作用总结

这段代码通过以下步骤绘制了线性可分支持向量机的决策边界:

  1. 定义网格:通过 np.meshgrid() 函数创建输入空间的网格点。
  2. 模型预测:使用 model.predict() 函数对网格中的每个点进行预测,确定该点的类别。
  3. 绘制决策边界:使用 plt.pcolormesh() 函数绘制每个区域的背景颜色,代表不同类别的区域。
  4. 绘制样本点:使用 plt.scatter() 函数绘制训练样本,显示真实的分类结果。

使用示例

假设你已经训练了一个 SVM 模型,并且有一些二维数据,那么你可以这样调用函数 plot_classifer()

# 假设我们有训练好的 SVM 模型和数据
svm_model = Hard_Margin_SVM()
svm_model.fit(X, y)

# 绘制决策边界
plot_classifer(svm_model, X, y)

总结

通过这段代码,你可以直观地看到 SVM 如何将样本分为两个类别,并展示它的分类边界。


http://www.kler.cn/a/371672.html

相关文章:

  • 推荐系统中的AB测试
  • docker占用磁盘过多问题
  • 微信小程序如何实现地图轨迹回放?
  • 线性代数(1)——线性方程组的几何意义
  • 在IDEA中运行Mybatis后发现取出的password值为null
  • 服务器数据恢复—异常断电导致服务器挂载分区无法访问的数据恢复案例
  • 使用Docker Compose简化微服务部署
  • 5G中NG接口
  • Cisco Packet Tracer 8.0 路由器静态路由配置
  • 设计模式---模版模式
  • 【机器学习】过拟合与欠拟合
  • 用哈希表封装unordered_map与unordered_set
  • sklearn机器学习实战
  • C++ 二叉树进阶:相关习题解析
  • C#实现与Windows服务的交互与控制
  • flinksql-Queries查询相关实战
  • 算法篇——动态规划最终篇 (js版)
  • uniapp position: fixed 兼容性不显示问题
  • Python Flask 数据库开发
  • Modbus TCP报文协议(ModbusTCP)
  • H5底部输入框点击弹起来的时候被软键盘遮挡bug
  • QT编译报错:-1: error: cannot find -lGL
  • 淘宝商品评价API的获取与应用
  • Prometheus自定义PostgreSQL监控指标
  • 直接删除Github上的文件
  • [flask] flask-mail邮件发送