当前位置: 首页 > article >正文

支持向量机 SVM

SVM 是机器学习中的一种分类方法,SVM 的目标是找到一个超平面,找到每个分类的数据点离超平面的距离最小,这些最小距离的数据点就是 Support Vector 支持向量。
在这里插入图片描述
SVM 分为线性可分和线性不可分,线性可分又分为硬距离和软距离,软距离添加了一些容错,允许某些数据点分类错误。对于线性不可分,通过核函数转为线性可分。

  • 线性可分,公式如下,确保 yi​(w⋅xi​+b)≥1
    在这里插入图片描述
  • 软距离,允许分类错误,确保 yi​(w⋅xi​+b)≥1−ξi
    在这里插入图片描述
  • 线性不可分,通过核函数将非线性函数转为线性函数,核函数可以是线性函数或者高斯函数。确保 0≤αi​≤C,α 为拉格朗日乘子。
    在这里插入图片描述

SKLearn 实现 SVM

线性可分,硬距离,完全可分。
在这里插入图片描述

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# 导入sklearn模拟二分类数据生成模块
from sklearn.datasets import make_blobs
# 生成模拟二分类数据集
X, y =  make_blobs(n_samples=150, n_features=2, centers=2, cluster_std=1.2, random_state=40)
# 设置颜色参数
colors = {0:'r', 1:'g'}
# 绘制二分类数据集的散点图
plt.scatter(X[:,0], X[:,1], marker='o', c=pd.Series(y).map(colors))
plt.show();

# 导入sklearn线性SVM分类模块
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score
# 创建模型实例
clf = LinearSVC(random_state=0, tol=1e-5)
# 训练
clf.fit(X_train, y_train)
# 预测
y_pred = clf.predict(X_test)
# 计算测试集准确率
print(accuracy_score(y_test, y_pred))

线性可分,软距离,大部分可分。
在这里插入图片描述

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

mean1, mean2 = np.array([0, 2]), np.array([2, 0])
covar = np.array([[1.5, 1.0], [1.0, 1.5]])
X1 = np.random.multivariate_normal(mean1, covar, 100)
y1 = np.ones(X1.shape[0])
X2 = np.random.multivariate_normal(mean2, covar, 100)
y2 = -1 * np.ones(X2.shape[0])
X_train = np.vstack((X1[:80], X2[:80]))
y_train = np.hstack((y1[:80], y2[:80]))
X_test = np.vstack((X1[80:], X2[80:]))
y_test = np.hstack((y1[80:], y2[80:]))
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

# 设置颜色参数
colors = {1:'r', -1:'g'}
# 绘制二分类数据集的散点图
plt.scatter(X_train[:,0], X_train[:,1], marker='o', c=pd.Series(y_train).map(colors))
plt.show();

from sklearn import svm
from sklearn.metrics import accuracy_score
# 创建svm模型实例
clf = svm.SVC(kernel='linear')
# 模型拟合
clf.fit(X_train, y_train)
# 模型预测
y_pred = clf.predict(X_test)
# 计算测试集准确率
print('Accuracy of soft margin svm based on sklearn: ', 
      accuracy_score(y_test, y_pred))

线性不可分,使用 RBF / 高斯 核函数,
在这里插入图片描述

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

mean1, mean2 = np.array([-1, 2]), np.array([1, -1])
mean3, mean4 = np.array([4, -4]), np.array([-4, 4])
covar = np.array([[1.0, 0.8], [0.8, 1.0]])
X1 = np.random.multivariate_normal(mean1, covar, 50)
X1 = np.vstack((X1, np.random.multivariate_normal(mean3, covar, 50)))
y1 = np.ones(X1.shape[0])
X2 = np.random.multivariate_normal(mean2, covar, 50)
X2 = np.vstack((X2, np.random.multivariate_normal(mean4, covar, 50)))
y2 = -1 * np.ones(X2.shape[0])
X_train = np.vstack((X1[:80], X2[:80]))
y_train = np.hstack((y1[:80], y2[:80]))
X_test = np.vstack((X1[80:], X2[80:]))
y_test = np.hstack((y1[80:], y2[80:]))
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

# 设置颜色参数
colors = {1:'r', -1:'g'}
# 绘制二分类数据集的散点图
plt.scatter(X_train[:,0], X_train[:,1], marker='o', c=pd.Series(y_train).map(colors))
plt.show();

from sklearn import svm
from sklearn.metrics import accuracy_score
# 创建svm模型实例
clf = svm.SVC(kernel='rbf')
# 模型拟合
clf.fit(X_train, y_train)
# 模型预测
y_pred = clf.predict(X_test)
# 计算测试集准确率
print('Accuracy of soft margin svm based on sklearn: ', 
      accuracy_score(y_test, y_pred))

总结

本文使用 SkLearn 实现不同类型 SVM 进行数据分类,除了 SVM,线性回归也可以进行分类,可以通过以下建议进行选择。

比较标准逻辑回归 (LR)支持向量机 (SVM)
数据的线性可分性适合线性可分数据适合线性和非线性数据
可解释性低(尤其是非线性核)
计算复杂性低(速度快)高(使用RBF核时较慢)
高维数据表现良好表现良好(尤其是文本数据)
不平衡数据易于调整调整较复杂
超参数调优少(只有正则化参数)多(如 ( C ) 和 ( gamma ))
常见应用欺诈检测、医疗诊断文本分类、图像识别

http://www.kler.cn/a/393713.html

相关文章:

  • Win10本地部署大语言模型ChatGLM2-6B
  • Android基于回调的事件处理
  • ubuntu 20.04 安装docker--小白学习之路
  • The Dedicated Few (10 player)
  • 2024年度漏洞态势分析报告,需要访问自取即可!(PDF版本)
  • netplan apply报错No module named ‘netifaces‘
  • 密码学在网络安全中的应用
  • 基于ABNF语义定义的HTTP消息格式
  • 基于微信小程序的电商平台+LW示例参考
  • html文本元素
  • 第三次作业
  • 浅谈:基于三维场景的视频融合方法
  • 丹摩征文活动 | 丹摩智算平台:服务器虚拟化的璀璨明珠与实战秘籍
  • C++设计模式和编程框架两种设计元素的比较与相互关系
  • Jenkins常见问题
  • 计算机网络(5)
  • Java final关键字
  • ios swift开发--ios远程推送通知配置
  • leetcode83. Remove Duplicates from Sorted List
  • 域名绑定服务器小白教程
  • LeetCode 热题100之技巧关卡
  • Leetcode:118. 杨辉三角——Java数学法求解
  • 飞牛云fnOS本地部署WordPress个人网站并一键发布公网远程访问
  • MaxKB
  • 2024 年使用 Postman 调用 WebService 接口图文教程
  • ES6的Iterator 和 for...of 循环