当前位置: 首页 > article >正文

落地级分类模型训练框架搭建(1):resnet18/50和mobilenetv2在CIFAR10上测试结果

目录

前言

1.分类结果测试汇总

2.训练过程可视化

ResNet18直接训练(准确率、召回率、Loss)

ResNet50直接训练(准确率、召回率、Loss)

3.模型权重分析

一般的训练,获取的模型权重分布

引入约束化训练的模型权重分布​编辑

4.测试结果

模型[6] ResNet50测试结果

模型[7] ResNet18测试结果

模型[10] ResNet18_009测试结果


前言

用CIFAR10测试一下自己搭建的分类模型训练框架,包括基本训练、知识蒸馏、模型稀疏化、剪枝微调、模型量化。最终,ResNet18取得96%的分类准确率,且剪枝91%的ResNet18可以取得94%的准确率。

1.分类结果测试汇总

说明:严格划分了训练集、验证集和测试集,没有将CIFAR10的测试集当做验证集

训练测试集为CIFAR10训练集随机采样95%,验证集为训练集剩余的5%,测试集为原CIFAR10的测试集。实验结果如下:

序号模型名称

训练

方式

输入尺寸

(H x W)

参数量

(M)

准确率召回率

耗时

(3060)(ms)

Flops

(G)

吞吐量(3060)

(K)

1MobileNetv3-[64, 64]4.21592.1292.121.7060.21612.3
2ResNet18-[64, 64]11.18291.9891.990.8300.14912.6
3ResNet18L1[64, 64]11.18292.3192.340.8410.14913.0
4ResNet18-bestL2[224, 224]11.18295.1295.011.2061.8241.6
5ResNet18-lastL2[224, 224]11.18295.4195.391.2671.8241.6
6ResNet50L2[224, 224]23.52996.0496.012.8974.1320.55
7ResNet18L2+di[224, 224]11.18296.0195.891.1991.8241.6
8ResNet18L1+di[224, 224]11.18295.5995.551.2041.8241.6
9ResNet18_009di[224, 224]0.96193.6393.610.6950.4623.4
10ResNet18_009di+fu[224, 224]0.96193.9893.950.6920.4623.4

说明L1表示引入L1正则,L2表示引入L2正则,di表示使用知识蒸馏,fu表示微调,009表示参数仅有原模型的9%,即剪枝了91%的参数。

测试结论

        1. CIFAR10上,不使用预训练模型,ResNet18准确率一般在92%~94%,使用预训练模型可以到95%,性能极限在96%左右。(修改第一层卷积核大小可以再提高一些)

        2.框架使用ImageNet上的预训练模型,没有修改模型,就能达到95.41%(模型[5]),说明整个训练框架能够充分训练,且很难发生过拟合(最后一轮模型[5]比最好一轮模型[4]测试结果还高,且验证精度基本等于测试精度)。

        3.使用知识蒸馏可以显著提高模型训练收敛速度和最终测试精度,模型[7]~[9]均使用模型[6]作为teacher model,且模型[6]仅直接训练一次,还未进行交叉验证的微调。

        4.利用L2正则实现权重衰减,再使用L1实现权重置0剪枝,可以使ResNet18稀疏性达到91%,最终也剪枝了91%,导致了性能下降,实际剪枝84%可能更好。

2.训练过程可视化

ResNet18直接训练(准确率、召回率、Loss)

ResNet50直接训练(准确率、召回率、Loss)

结论:使用一系列的数据增强(随机翻转、旋转、裁剪、锐化、仿射变换、对比度调整、区域高斯模糊、区域翻转、区域打码、各种噪声等),外加正则化(L1、L2、dropout、batchnorm等),指数衰减学习率,模型基本不会过拟合。因此,可以对一个数据集使用交叉验证,尽可能拟合,即可提高同源数据的预测准确率(模型[9]到模型[10]的提升)。

3.模型权重分析

一般的训练,获取的模型权重分布

引入约束化训练的模型权重分布

结论:引入约束化训练,使得整体权重更接近高斯分布,却不会出现离群值,更方便后续的模型量化。

4.测试结果

模型[6] ResNet50测试结果

Class airplane: Precision: 0.970884, Recall: 0.967000, F1-Score: 0.968938
Class automobile: Precision: 0.981763, Recall: 0.969000, F1-Score: 0.975340
Class bird: Precision: 0.947264, Recall: 0.952000, F1-Score: 0.949626
Class cat: Precision: 0.900771, Recall: 0.935000, F1-Score: 0.917566
Class deer: Precision: 0.966169, Recall: 0.971000, F1-Score: 0.968579
Class dog: Precision: 0.941117, Recall: 0.927000, F1-Score: 0.934005
Class frog: Precision: 0.983690, Recall: 0.965000, F1-Score: 0.974255
Class horse: Precision: 0.988753, Recall: 0.967000, F1-Score: 0.977755
Class ship: Precision: 0.970238, Recall: 0.978000, F1-Score: 0.974104
Class truck: Precision: 0.953786, Recall: 0.970000, F1-Score: 0.961824
Class macro avg: Precision: 0.960443, Recall: 0.960100, F1-Score: 0.960199

模型[7] ResNet18测试结果

Class airplane: Precision: 0.964215, Recall: 0.970000, F1-Score: 0.967099
Class automobile: Precision: 0.975075, Recall: 0.978000, F1-Score: 0.976535
Class bird: Precision: 0.954683, Recall: 0.948000, F1-Score: 0.951330
Class cat: Precision: 0.895494, Recall: 0.934000, F1-Score: 0.914342
Class deer: Precision: 0.963964, Recall: 0.963000, F1-Score: 0.963482
Class dog: Precision: 0.947040, Recall: 0.912000, F1-Score: 0.929190
Class frog: Precision: 0.968379, Recall: 0.980000, F1-Score: 0.974155
Class horse: Precision: 0.992828, Recall: 0.969000, F1-Score: 0.980769
Class ship: Precision: 0.976884, Recall: 0.972000, F1-Score: 0.974436
Class truck: Precision: 0.962376, Recall: 0.972000, F1-Score: 0.967164
Class macro avg: Precision: 0.960094, Recall: 0.959800, F1-Score: 0.959850

模型[10] ResNet18_009测试结果

Class airplane: Precision: 0.963001, Recall: 0.937000, F1-Score: 0.949823
Class automobile: Precision: 0.969031, Recall: 0.970000, F1-Score: 0.969515
Class bird: Precision: 0.929949, Recall: 0.916000, F1-Score: 0.922922
Class cat: Precision: 0.862275, Recall: 0.864000, F1-Score: 0.863137
Class deer: Precision: 0.937870, Recall: 0.951000, F1-Score: 0.944389
Class dog: Precision: 0.886076, Recall: 0.910000, F1-Score: 0.897879
Class frog: Precision: 0.949654, Recall: 0.962000, F1-Score: 0.955787
Class horse: Precision: 0.968813, Recall: 0.963000, F1-Score: 0.965898
Class ship: Precision: 0.975635, Recall: 0.961000, F1-Score: 0.968262
Class truck: Precision: 0.955268, Recall: 0.961000, F1-Score: 0.958126
Class macro avg: Precision: 0.939757, Recall: 0.939500, F1-Score: 0.939574


http://www.kler.cn/a/519262.html

相关文章:

  • vim的多文件操作
  • Git进阶笔记系列(01)Git核心架构原理 | 常用命令实战集合
  • Spring整合Mybatis、junit纯注解
  • RV1126画面质量四:GOP改善画质
  • 物业管理平台系统提升社区智能化服务效率与管理水平
  • hedfs和hive数据迁移后校验脚本
  • 高级java每日一道面试题-2025年01月24日-框架篇[SpringBoot篇]-如何理解 Spring Boot 中的 Starters(启动器) ?
  • three.js+WebGL踩坑经验合集(4.1):THREE.Line2的射线检测问题(注意本篇说的是Line2,同样也不是阈值方面的问题)
  • 多模态论文笔记——ViViT
  • OpenAI-Edge-TTS的使用
  • 深入解析Gradle项目发布配置:从构建到仓库部署
  • SAM 2运行笔记
  • 为AI聊天工具添加一个知识系统 之69 详细设计 之10 三种中台和时间度量 之2
  • 5.如何减少顶点数
  • 【elasticsearch】reindex 操作将索引的数据复制到另一个索引
  • 【2024年华为OD机试】 (A卷,200分)- 几何平均值最大子数组(JavaScriptJava PythonC/C++)
  • 《CPython Internals》阅读笔记:p356-p359
  • Spring 框架基础:IOC 与 AOP 原理剖析及面试要点
  • Spring Boot 无缝集成SpringAI的函数调用模块
  • android12源码中用第三方APK替换原生launcher
  • 半小时速通flume-flume正文学习
  • 【深入理解SpringCloud微服务】Sentinel源码解析——DegradeSlot熔断规则
  • 【漫话机器学习系列】060.前馈神经网络(Feed Forward Neural Networks, FFNN)
  • 能源新动向:智慧能源平台助力推动新型电力负荷管理系统建设
  • 面试技巧——压力面题目与参考答案
  • 软件越跑越慢的原因分析