当前位置: 首页 > article >正文

人工智能知识分享第三天-机器学习中交叉验证和网格搜索

交叉验证和网格搜索

  • 交叉验证解释:

    • 概述:
      它是一种更加完善的, 可信度更高的模型预估方式, 思路是: 把数据集分成n份, 每次都取1份当做测试集, 其它的当做训练集, 然后计算模型的: 评分.
      然后再用下1份当做测试集, 其它当做训练集, 计算模型评分, 分成几份, 就进行几次计算, 最后计算: 所有评分的均值, 当做模型的最终评分.
      好处:
      交叉验证的结果 比 单一切分训练集和测试集 获取评分结果, 可信度要高.
      细节:
      1.把数据集分成N份, 就叫: N折交叉验证, 例如: 分成4份, 就叫4折交叉验证.
      2.交叉验证一般不会单独使用, 而是结合网格搜索一起使用.
  • 网格搜索:

    • 概述:
      它是机器学习内置的API, 一般结合交叉验证一起使用, 指的是: GridSearchCV, 目的是: 寻找最优超参.
      格式:
      GridSearchCV(estimator模型对象, params = [超参可能出现的值1, 值2, 值3…], cv=折数)
      超参解释:
      模型中需要用户(程序员)手动传入的参数 -> 统称为: 超参.
      目的:
      网格搜索 + 交叉验证 目的都是为了 寻找模型的 最优的解决方案, 追求较高的效率.
      代码演示:

# 导入工具包
from sklearn.datasets import load_iris  # 加载鸢尾花测试集的.
from sklearn.model_selection import train_test_split, GridSearchCV  # 分割训练集和测试集的,  网格搜索 + 交叉验证.
from sklearn.preprocessing import StandardScaler  # 数据标准化的
from sklearn.neighbors import KNeighborsClassifier  # KNN算法 分类对象
from sklearn.metrics import accuracy_score  # 模型评估的, 计算模型预测的准确率

# 1. 加载鸢尾花数据集.
iris_data = load_iris()
# 2. 数据预处理.
x_train, x_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target, test_size=0.2, random_state=22)
# 3. 特征工程, 数据标准化(特征预处理)
transfer = StandardScaler()
# 分别对训练集 和 测试集进行标准化处理.
x_train = transfer.fit_transform(x_train)  # 训练(拟合) + 转换
x_test = transfer.transform(x_test)  # 只进行转换, 不进行拟合.

# 4. 模型训练.
# 4.1 创建KNN 分类器对象 -> 模型对象.
estimator = KNeighborsClassifier()  # 注意: 不要传入超参, 即: K的数值.
# 4.2 定义字典, 记录: 超参可能出现的值.
param_dict = {'n_neighbors': list(range(1, 21))}    # K的范围: 1 ~ 21 包左不包右
# 4.3 创建网格搜索对象.
# 参1: 模型对象, 参2: 超参字典, 参3: 交叉验证的折数.
# 处理完之后, 会返回1个功能更加强大的模型对象.
estimator = GridSearchCV(estimator, param_dict, cv=4)

# 4.4 模型训练.
estimator.fit(x_train, y_train)

# 5. 模型预测.
y_predict = estimator.predict(x_test)
print(f'预测结果是: {y_predict}')

# 6. 打印下 网格搜索 和 交叉验证的结果.
print(estimator.best_score_)        # 最优组合的 平均分,    0.9666666666666666
print(estimator.best_estimator_)    # 最优组合的 模型对象
print(estimator.best_params_)       # 最优组合的 超参数(供参考) {'n_neighbors': 3}
print(estimator.cv_results_)        # 所有组合的 评分结果(过程)
print('-' * 22)

# 7. 结合上述的结果, 最模型再次做评估.
estimator = KNeighborsClassifier(n_neighbors=3)
estimator.fit(x_train, y_train)
y_predict = estimator.predict(x_test)
print(f'模型准确率为: {accuracy_score(y_test, y_predict)}')   # 0.9666666666666667

坚持分享 共同进步 如有错误 欢迎指出


http://www.kler.cn/a/458175.html

相关文章:

  • Java List 集合详解:基础用法、常见实现类与高频面试题解析
  • 一种基于动态部分重构的FPGA自修复控制器
  • C语言宏和结构体的使用代码
  • 使用three.js 实现vr全景图展示,复制即可用
  • 最小特权的例子
  • MIT Cheetah 四足机器人的动力学及算法 (I) —— 简化动力学模型
  • uniapp不能直接修改props的数据原理浅析
  • 《Virt A Mate(VAM)》免安装豪华版v1.22中文汉化整合
  • Nacos配置管理+共享配置、配置热更新
  • [Unity Shader][Unity Shader][图形渲染]Shader数学基础19-选择使用3×3或4×4变换矩阵的技巧
  • 音视频入门基础:MPEG2-TS专题(23)——通过FFprobe显示TS流每个packet的信息
  • 设计宝藏解压密码
  • 单片机优先级
  • Java实现简单爬虫——爬取疫情数据
  • 定义Shape:打造属于你的独特图形
  • YOLOv10目标检测-训练自己的数据
  • LAION-SG:一个大规模、高质量的场景图结构注释数据集,为图像-文本模型训练带来了革命性的进步。
  • leecode377.组合总和IV
  • 【MySQL】十三,关于MySQL的全文索引
  • jangow靶机
  • 【探花交友】day01—项目介绍与环境搭建
  • 10道JavaWeb常问面试题
  • Dify服务器部署教程
  • Python中构建终端应用界面利器——Blessed模块
  • QT笔记- QTreeView + QFileSystemModel 当前位置的保存与恢复 #选中 #保存当前索引
  • 2025年我国网络安全发展形势展望