当前位置: 首页 > article >正文

第15章:ConvNeXt图像分类实战:遥感场景分类【包含本地网页部署、迁移学习】

目录

1. ConvNeXt 模型

2. 遥感场景建筑识别

2.1 数据集

2.2 训练参数

2.3 训练结果

2.4 本地部署推理

3. 下载


1. ConvNeXt 模型

ConvNeXt是一种基于卷积神经网络(CNN)的现代架构,由Facebook AI Research (FAIR) 团队在2022年提出。它通过借鉴Transformer的设计思想,对传统CNN进行了改进,使其在图像分类等任务中表现优异,甚至超越了Vision Transformers(ViT)

详细介绍:

基于ConvNeXt网络的图像识别-CSDN博客

核心思想

ConvNeXt的核心思想是将Transformer的成功设计理念(如ViT)引入CNN,同时保留卷积的固有优势。通过一系列现代化改进,ConvNeXt在保持高效性的同时,提升了性能。

主要改进

  1. 大卷积核:使用更大的卷积核(如7x7)来扩大感受野,类似于Transformer中自注意力机制捕捉全局信息的能力。

  2. 分层设计:采用类似ResNet的分层结构,逐步降低分辨率并增加通道数,以提取多尺度特征。

  3. 倒置瓶颈结构:借鉴MobileNetV2的倒置瓶颈设计,先扩展通道数再进行深度卷积,最后压缩通道数,提升计算效率。

  4. Layer Normalization:用Layer Normalization替换Batch Normalization,更适合小批量训练,并提升模型稳定性。

  5. GELU激活函数:使用GELU激活函数替代ReLU,因其在Transformer中的表现更佳。

  6. 减少激活和归一化层:减少不必要的激活和归一化层,简化网络结构,提升性能。

  7. Stochastic Depth:引入随机深度(Stochastic Depth),在训练时随机丢弃部分层,增强模型泛化能力。

2. 遥感场景建筑识别

ConvNeXt 实现的model部分代码如下面所示,这里如果采用官方预训练权重的话,会自动导入官方提供的最新版本(ImageNet)的权重

2.1 数据集

数据集文件如下:

具体图像示例:

标签如下:

{
    "0": "airport",
    "1": "bridge",
    "2": "church",
    "3": "forest",
    "4": "lake",
    "5": "river",
    "6": "skyscraper",
    "7": "stadium",
    "8": "statue",
    "9": "tower",
    "10": "urbanPark"
}

其中,训练集的总数为820,验证集的总数为345

2.2 训练参数

训练的参数如下:

    parser.add_argument("--model", default='tiny', type=str,help='tiny,small,base,large')
    parser.add_argument("--pretrained", default=True, type=bool)       # 采用官方权重
    parser.add_argument("--freeze_layers", default=True, type=bool)    # 冻结权重

    parser.add_argument("--batch-size", default=8, type=int)
    parser.add_argument("--epochs", default=30, type=int)

    parser.add_argument("--optim", default='AdamW', type=str,help='SGD,Adam,AdamW')         # 优化器选择

    parser.add_argument('--lr', default=0.01, type=float)
    parser.add_argument('--lrf',default=0.01,type=float)                  # 最终学习率 = lr * lrf

    parser.add_argument('--save_ret', default='runs', type=str)             # 保存结果
    parser.add_argument('--data_train',default='./data/train',type=str)           # 训练集路径
    parser.add_argument('--data_val',default='./data/val',type=str)               # 验证集路径

需要注意的是网络分类的个数不需要指定,摆放好数据集后,代码会根据数据集自动生成!

更换数据集的话,将data-train和data-val路径更改即可,一键运行!

trian脚本会在训练同时自动验证,生成训练和验证的曲线图和指标

网络模型信息如下:

 "train parameters": {
        "model version": "tiny",
        "pretrained": true,
        "freeze_layers": true,
        "batch_size": 8,
        "epochs": 30,
        "optim": "AdamW",
        "lr": 0.01,
        "lrf": 0.01,
        "save_folder": "runs"
    },
    "dataset": {
        "trainset number": 820,
        "valset number": 345,
        "number classes": 11
    },
    "model": {
        "total parameters": 27818891.0,
        "train parameters": 9995,
        "flops": 4463390208.0
    },

2.3 训练结果

所有的结果都保存在 save_ret 目录下,这里是 runs 

weights 下有最好和最后的权重,在训练完成后控制台会打印最好的epoch

这里只展示部分结果:可以看到网络没有完全收敛,增大epoch会得到更好的效果

最后一轮结果:

    "epoch:29": {
        "train info": {
            "accuracy": 0.9987804877926978,
            "airport": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "bridge": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "church": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "forest": {
                "Precision": 1.0,
                "Recall": 0.987,
                "Specificity": 1.0,
                "F1 score": 0.9935
            },
            "lake": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "river": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "skyscraper": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "stadium": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "statue": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "tower": {
                "Precision": 1.0,
                "Recall": 1.0,
                "Specificity": 1.0,
                "F1 score": 1.0
            },
            "urbanPark": {
                "Precision": 0.9872,
                "Recall": 1.0,
                "Specificity": 0.9987,
                "F1 score": 0.9936
            },
            "mean precision": 0.9988363636363636,
            "mean recall": 0.9988181818181818,
            "mean specificity": 0.9998818181818181,
            "mean f1 score": 0.9988272727272729
        },
        "valid info": {
            "accuracy": 0.8463768115696703,
            "airport": {
                "Precision": 0.9231,
                "Recall": 0.8571,
                "Specificity": 0.997,
                "F1 score": 0.8889
            },
            "bridge": {
                "Precision": 0.9032,
                "Recall": 0.8485,
                "Specificity": 0.9904,
                "F1 score": 0.875
            },
            "church": {
                "Precision": 0.7647,
                "Recall": 0.8125,
                "Specificity": 0.9878,
                "F1 score": 0.7879
            },
            "forest": {
                "Precision": 0.9259,
                "Recall": 0.7576,
                "Specificity": 0.9936,
                "F1 score": 0.8333
            },
            "lake": {
                "Precision": 0.9091,
                "Recall": 0.7692,
                "Specificity": 0.997,
                "F1 score": 0.8333
            },
            "river": {
                "Precision": 0.7872,
                "Recall": 0.9024,
                "Specificity": 0.9671,
                "F1 score": 0.8409
            },
            "skyscraper": {
                "Precision": 0.875,
                "Recall": 0.7,
                "Specificity": 0.997,
                "F1 score": 0.7778
            },
            "stadium": {
                "Precision": 0.8943,
                "Recall": 0.9649,
                "Specificity": 0.9437,
                "F1 score": 0.9283
            },
            "statue": {
                "Precision": 0.8095,
                "Recall": 0.6296,
                "Specificity": 0.9874,
                "F1 score": 0.7083
            },
            "tower": {
                "Precision": 0.7143,
                "Recall": 0.4167,
                "Specificity": 0.994,
                "F1 score": 0.5263
            },
            "urbanPark": {
                "Precision": 0.7,
                "Recall": 0.875,
                "Specificity": 0.9617,
                "F1 score": 0.7778
            },
            "mean precision": 0.8369363636363637,
            "mean recall": 0.7757727272727273,
            "mean specificity": 0.9833363636363637,
            "mean f1 score": 0.7979818181818181
        }
    }

训练集和测试集的混淆矩阵:

ROC曲线和auc值:

2.4 本地部署推理

推理是指没有标签,只有图片数据的情况下对数据的预测,这里使用了网页推理

值得注意的是,如果训练了自己的数据集,需要对infer脚本进行更改,如下:

# 参数
MODEL = 'tiny'
LABELS = r'D:\project\ConvNeXt全家桶\runs\class_indices.json'
PTH = r'D:\project\ConvNeXt全家桶\runs\weights\best.pth'
IMAGE_PATH = r'D:\project\ConvNeXt全家桶\data\train\airport\0.jpg'

运行:

streamlit run D:\project\ConvNeXt全家桶\infer.py

3. 下载

关于本项目代码和数据集、训练结果的下载:

基于ConVNeXt神经网络模型实现的迁移学习、图像识别项目:遥感场景分类网页推理资源-CSDN文库

关于Ai 深度学习图像识别、医学图像分割改进系列:AI 改进系列_听风吹等浪起的博客-CSDN博客

神经网络改进完整实战项目:改进系列_听风吹等浪起的博客-CSDN博客


http://www.kler.cn/a/588469.html

相关文章:

  • git subtree在本地合并子仓库到主仓库
  • KY-038 声音传感器如何工作以及如何将其与 ESP32 连接
  • java 线程池Executor框架
  • 深入解析 Vue 3 Teleport:原理、应用与最佳实践
  • 使用Inno Setup将Unity程序打成一个安装包
  • Native层逆向:ARM汇编与JNI调用分析
  • node.js-WebScoket心跳机制(服务器定时发送数据,检测连接状态,重连)
  • 游戏成瘾与学习动力激发策略研究——自我效能理论
  • 深入理解Linux网络随笔(七):容器网络虚拟化--Veth设备对
  • 基于javaweb的SSM+Maven网上选课管理系统设计与实现(源码+文档+部署讲解)
  • JavaScript性能优化的12种方式
  • Function 和 Consumer函数式接口
  • Ubuntu docker镜像恢复至原始文件
  • React使用路由表
  • 使用GoldenGate完成SQLserver到Oracle的数据实时同步
  • Django项目之订单管理part3
  • Markdig:强大的 .NET Markdown 解析器详解
  • 【AI时代移动端安全开发实战:从基础防护到智能应用】
  • 责任链模式:优雅处理请求的设计艺术
  • k8s 网络基础解析