当前位置: 首页 > article >正文

大数据-197 数据挖掘 机器学习理论 - scikit-learn 泛化能力 交叉验证

点一下关注吧!!!非常感谢!!持续更新!!!

目前已经更新到了:

  • Hadoop(已更完)
  • HDFS(已更完)
  • MapReduce(已更完)
  • Hive(已更完)
  • Flume(已更完)
  • Sqoop(已更完)
  • Zookeeper(已更完)
  • HBase(已更完)
  • Redis (已更完)
  • Kafka(已更完)
  • Spark(已更完)
  • Flink(已更完)
  • ClickHouse(已更完)
  • Kudu(已更完)
  • Druid(已更完)
  • Kylin(已更完)
  • Elasticsearch(已更完)
  • DataX(已更完)
  • Tez(已更完)
  • 数据挖掘(正在更新…)

章节内容

上节我们完成了如下的内容:

  • scikit-learn 算法库实现
  • 案例1 红酒
  • 案例 2 乳腺癌

在这里插入图片描述

交叉验证

确定了 K 之后,我们还能够发现一件事情,每次运行的时候学习曲线都在变化,模型的效果有时好有时坏,这是为什么?
实际上,这都是由于【训练集】和【测试集】的划分不同造成的,模型每次都使用不同的训练集进行训练,不同的测试集进行测试,自然也就有不同的结果。
在业务中,我们训练数据往往都是以往已经有的历史数据,但我们的测试数据却是新进入系统的数据,我们追求模型的效果,但是追求的是模型在未知数据集上的效果,在陌生的数据集上表现的能力被称为泛化能力,即我们追求的是模型的泛化能力。

泛化能力

我们在进行学习的时候,通常会将一个样本集分化为【训练集】和【测试集】,其中训练集用于模型的学习和训练,而后测试集通常用于评估训练好的模型对于数据的预测性的评估。

  • 训练误差代表模型在训练集上的错分样本比率。
  • 测试误差代表模型在测试集上的错分样本比率。

训练误差的大小,用来判断给定问题是不是一个容易学习的问题,测试误差反应了模型对未知数据的预测能力,测试误差小的学习方法具有很好的预测能力,如果得到的训练集和测试集没有交集,通常将此预测能力称为泛化能力(generalization ability)。
我们认为,如果模型在一套训练集和数据集上表现优秀,那说明不了问题,只能在众多不同的训练集和测试集上都表现优秀,模型才是一个稳定的模型,模型才是具有真正意义上的泛化能力。
为此,机器学习领域有发挥神作用的技能:【交叉验证】,来帮助我们认识模型。

k折交叉验证

最常用的交叉验证是 K 折交叉验证,我们知道训练集和测试集的划分会干扰模型的结果,因此用交叉验证 N 次的结果求出均值,是对模型效果的一个更好的度量。

在这里插入图片描述

带交叉验证的学习曲线

对于带交叉验证的学习曲线,我们需要观察的就不仅仅是最高的准确率了,而是准确率高且方差还相对较小的点,这样的点泛化能力才是最强的,在交叉验证+学习曲线的作用下,我们选出的超参数能够保证更好的泛化能力。

from sklearn.model_selection import cross_val_score as CVS
Xtrain,Xtest,Ytrain,Ytest = train_test_split(X,y,test_size=0.2,random_state=420)
clf = KNeighborsClassifier(n_neighbors=8)
#训练集对折6次,一共6个预测率输出
cvresult = CVS(clf,Xtrain,Ytrain,cv=6) 
#每次交叉验证运行时估算器得分的数组
cvresult 

执行结果如下:
在这里插入图片描述
查看均值、方差

# 均值:查看模型的平均效果
cvresult.mean()
# 方差:查看模型是否稳定
cvresult.var()

执行结果如下所示:
在这里插入图片描述
绘制图片进行查看:

score = []
var = []
#设置不同的k值,从1到19都看看
krange=range(1,20) 
for i in krange:
    clf = KNeighborsClassifier(n_neighbors=i)
    cvresult = CVS(clf,Xtrain,Ytrain,cv=5)
    # 每次交叉验证返回的得分数组,再求数组均值
    score.append(cvresult.mean()) 
    var.append(cvresult.var())

plt.plot(krange,score,color='k')
plt.plot(krange,np.array(score)+np.array(var)*2,c='red',linestyle='--')
plt.plot(krange,np.array(score)-np.array(var)*2,c='red',linestyle='--')

运行结果如下:
在这里插入图片描述
输出的图片如下所示:
在这里插入图片描述

是否需要验证集

最标准,最严谨的交叉验证应该有三组数据:训练集、验证集和测试集。当我们获取一组数据之后:

  • 先将数据集分成整体的训练集和测试集
  • 然后我们把训练集放入交叉验证中
  • 从训练集中分割更小的训练集(k-1 份)和验证机(1 份)
  • 返回的交叉验证结果其实是验证集上的结果
  • 使用验证集寻找最佳参数,确认一个我们认为泛化能力最佳的模型
  • 将这个模型使用在测试集上,观察模型的表现

通常来说,我们认为经过验证集找出最终参数后的模型的泛化能力是增强了的,因此模型在未知数据(测试集)上的效果会更好,但尴尬的是,模型经过交叉验证在验证集上的调参之后,在测试集上的结果没有变好的情况时有发生。

原因其实是:

  • 我们自己分的训练集和测试集,会影响模型的效果
  • 交叉验证后的模型的泛化能力增强了,表现它在未知数据集上方差更小,平均水平更高,但却无法保证它在现在分出来的测试集上预测能力最强
  • 如此来说,是否有测试集的存在,其实意义不大

如果我们相信交叉验证的调整结果是增强了模型的泛化能力,那即便测试集上的测试结果并没有变好(甚至变坏),我们也认为模型是成功的。
如果我们不相信交叉验证的调整结果能够增强模型的泛化能力,而一定要依赖测试集来进行判断,我们完全没有进行交叉验证的必要,直接用测试集上的结果用来跑学习曲线就好了。
所以,究竟是否需要验证集,其实是存在争议的,在严谨的情况下,大家还是使用了有验证集的方式。

其他交叉验证

交叉验证的方法不止“K 折”一种,分割训练集和测试集的方法也不止一种,分门别类的交叉验证占据了 sklearn 中非常长的一章。
所有的交叉验证都是在分割训练集和测试集,只不过侧重的方向不同。

  • K 折就是按顺序取训练集和测试集
  • ShuffleSpilt 就侧重于让测试集分布在数据的全方位之内
  • StratifiedKFold 则认为训练数据和测试数据必须在每个标签分类中占有相同的比例

各类交叉验证的原理繁琐,大家在机器学习的道路上一定会逐渐遇到更难的交叉验证,但是万变不离其宗:本质上交叉验证是为了解决训练集和测试集的划分对模型带来的影响,同时检测模型的泛化能力。

在这里插入图片描述

交叉验证的折数不可太大,因为折数越大抽出来的数据集越小,训练数据所带来的信息量就会越小,模型会越来越不稳定。

避免折数太大

如果你发现不使用交叉验证的时候模型表现很好,一使用校验验证模型的效果就骤降

  • 一定要查看你的标签是否有顺利
  • 然后就是查看你的数据量是否太小,折数是否太高

如果将上面例题的代码中将 cv 将 5 改成 100,代码如下所示:

score = []
var = []
#设置不同的k值,从1到19都看看
krange=range(1,20) 
for i in krange:
    clf = KNeighborsClassifier(n_neighbors=i)
    cvresult = CVS(clf,Xtrain,Ytrain,cv=100)
    # 每次交叉验证返回的得分数组,再求数组均值
    score.append(cvresult.mean()) 
    var.append(cvresult.var())

plt.plot(krange,score,color='k')
plt.plot(krange,np.array(score)+np.array(var)*2,c='red',linestyle='--')
plt.plot(krange,np.array(score)-np.array(var)*2,c='red',linestyle='--')

生成图片如下图所示:
在这里插入图片描述

折数过大:

  • 运算效率变慢
  • 预测率方差变大,难以保证在新的数据集达到预期预测率

http://www.kler.cn/a/372828.html

相关文章:

  • 信息奥赛一本通 1168:大整数加法
  • 游戏引擎学习第81天
  • 深度学习 Pytorch 基本优化思想与最小二乘法
  • C++第十五讲:异常
  • 通过图形界面展现基于本地知识库构建RAG应用
  • 蓝桥杯 Python 组知识点容斥原理
  • 视频设备一体化监控运维方案
  • openGauss开源数据库实战十四
  • Flink难点和高频考点:Flink的反压产生原因、排查思路、优化措施和监控方法
  • 性能测试——Jmeter实战
  • DAIN-SQL,DAIL-SQL,C3-SQL和 DIN-SQL 技术的理解和异同点
  • LSTM——长短期记忆神经网络
  • Linux 调度SCHED_FIFO或SCHED_RR
  • 传统机器学习总结
  • 目标检测一阶段模型
  • BERT的中文问答系统22
  • rook-ceph mon 报错 e9 handle_auth_request failed to assign global_id
  • 时尚零售企业商品计划管理的数字化之旅
  • 「C/C++」C++设计模式 之 抽象工厂模式(Abstract Factory)
  • HTTP相关返回值异常原因分析,第二部分
  • Mac在Typora配置PicGo图床,以github为例
  • rsync异地备份
  • 详解机器学习经典模型(原理及应用)——朴素贝叶斯
  • 【iOS】YYModel初学习
  • ssm014基于JSP的乡镇自来水收费系统+jsp(论文+源码)_kaic
  • 图书管理系统汇报