当前位置: 首页 > article >正文

Vision Transformer 神经网络在水果、动物、血细胞上的应用【深度学习、PyTorch、图像分类】

源链接(英文):https://www.orzzz.net/directory/codes/VisionTransformer/index.html
代码链接(中文):https://github.com/Illusionna/VisionTransformer

1. 介绍

2023-11-25 我在 GitHub 上提交了 EfficientTransformer 的比赛代码,近一年后的今天,我在删除电脑磁盘上吃灰的文件时,再次发现到它。由于去年写的函数接口很简约,所以这次重写 ViT 代码,修改了很多冗余的语句,并在其他图片集上进行测试,没想到这个神经网络模型效果依然这么好。

该项目代码是用于图片集训练、测试、预测的 Vision Transformer 神经网络架构。你可以下载已提供的数据集,当然,你也可以将代码运行在自己的图片集上。然而,需要注意的是,该项目仅是 demo 演示,不建议你将代码部署工业生产。

2. 数据集

已提供的图片集有三种:动物、水果、血细胞。你可以自定义新的图片集,不过需要满足一定的规则,否则代码可能无法运行。参考已提供的三种图片集文件夹或文件放置结构即可,发现其中的规律对你来说轻而易举?

三种图片集共约 1 GB 磁盘占用,所以 git 仓库里只放置了“水果”的图片集压缩包,另外两种数据集推荐使用百度网盘下载。在你运行代码前记得解压缩哦。

  • 血细胞(下载链接 1 路由 BaiduNetdisk 提取码 nl86、下载链接 2 路由 Amazon):包含 neutrophils、eosinophils、basophils、lymphocytes、monocytes、immature granulocytes、erythroblasts、platelets 八类。
  • 动物(下载链接 BaiduNetdisk 提取码 nl86):包括猫、狗两类。
  • 水果(下载链接 BaiduNetdisk 提取码 nl86):包括 Apple、Carambola、Pear、Plum、Tomatoes 五类。

说句题外话,你可以把这三类图片集整合到一块,形成一个共计 15 类的更大数据集,再尝试训练测试 ViT 模型。项目工作区结构如下:

在这里插入图片描述

项目结构文件的描述如下:

.
├── cache
│   ├── log
│   │   ├── Animals-train_loss(0.48440)valid_loss(0.40608).pt (动物图片训练 200 次后的权重)
│   │   └── Fruits-train_loss(0.12103)valid_loss(0.04023).pt (该权重可以直接用于测试、预测)
│   ├── Animals-info.json (动物数据集详情)
│   ├── Animals-map.json (动物数据集划分映射)
│   ├── Animals-net.txt (动物图片集对应的神经网络结构)
│   ├── Animals-test.json
│   ├── Animals-train.json
│   ├── Fruits-info.json
│   ├── Fruits-map.json
│   ├── Fruits-net.txt
│   ├── Fruits-test.json (水果测试集正确率)
│   └── Fruits-train.json (水果训练集训练过程)
├── datasets
│   ├── Animals (动物数据集)
│   │   ├── cat (猫类)
│   │   └── dog (狗类)
│   ├── Bloodcells
│   │   ├── basophil
│   │   ├── eosinophil
│   │   ├── erythroblast
│   │   ├── ig
│   │   ├── lymphocyte
│   │   ├── monocyte
│   │   ├── neutrophil
│   │   ├── platelet
│   │   └── README.md (血细胞数据集介绍, 不必管它)
│   ├── Fruits (水果数据集)
│   │   ├── Apple (苹果类)
│   │   │   ├── Apple (1).jpg (苹果类的某张图片)
│   │   │   ├── Apple (2).jpg
│   │   │   └── Apple (3).jpg
│   │   ├── Carambola
│   │   ├── Pear (梨子类)
│   │   │   ├── Pear (1).jpg
│   │   │   └── Pear (2).jpg (梨子类的某张照片)
│   │   ├── Plum
│   │   └── Tomatoes
│   ├── Unknown-Fruits (未知水果, 被用于预测)
│   │   ├── Fruit (a).jpg
│   │   ├── Fruit (b).jpg
│   │   ├── Fruit (c).jpg
│   │   ├── Fruit (d).jpg (某个未知的水果, 使用 Predict.py 代码进行预测)
│   │   └── Fruit (e).jpg
│   ├── Fruits.zip (水果数据集压缩包, 解压后, 顶替 Fruits 文件夹所有示例文件)
│   └── Unknown-Fruits.zip (未知水果的压缩包, 解压缩后, 顶替 Unknown-Fruits 文件夹)
├── utils
│   ├── einops (一个面向眼睛编程的张量库, 巨牛)
│   │   ├── experimental
│   │   │   ├── __init__.py
│   │   │   └── indexing.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── _einmix.py
│   │   │   ├── chainer.py
│   │   │   ├── flax.py
│   │   │   ├── keras.py
│   │   │   ├── oneflow.py
│   │   │   ├── paddle.py
│   │   │   ├── tensorflow.py
│   │   │   └── torch.py
│   │   ├── __init__.py
│   │   ├── _backends.py
│   │   ├── _torch_specific.py
│   │   ├── array_api.py
│   │   ├── einops.py
│   │   ├── packing.py
│   │   ├── parsing.py
│   │   └── py.typed
│   ├── linformer (一个把 Transformer 降低成线性复杂度的库)
│   │   ├── __init__.py
│   │   ├── linformer.py
│   │   └── reversible.py
│   ├── efficient.py (ViT 架构模型类)
│   ├── interface.py (自定义的接口文件)
│   ├── net.py (自定义的神经网络结构文件)
│   └── tool.py (自定义的工具文件)
├── Illustrate.py (插图函数, 需要单独安装库 >>> pip install matplotlib)
├── Predict.py (预测函数, 如果测试效果很好, 可将权重用于预测未知图片集)
├── Preprocessing.py (预处理函数, 对数据集进行处理, 之后可进行训练、测试、预测)
├── REAME.md (你正在阅读的文件)
├── Test.py (测试函数, 用于检测训练好的结果在测试集上的正确率)
└── Train.py (预处理完成后, 即可进行训练)

3. 依赖

该项目既支持 CPU 计算也支持 GPU 计算,如果你的电脑有图形处理器,可在终端使用 nvidia-smi 指令查看显卡型号,我的 GPU 显存是 4 GB 大小,且指令返回 CUDA 版本是 12.5 的型号。

NVIDIA-SMI 556.12   Driver Version: 556.12  CUDA Version: 12.5

你可在 PyTorch 官网找到适合自己电脑的 Python 库,所以我在 Windows 11 上安装 torch-cu121 库。

在这里插入图片描述

推荐使用 Anaconda、Miniconda 或者 venv 虚拟环境下载 Python 3.10+ 的依赖库。

Conda

  • conda create -n GPU python==3.10.0
    
  • conda activate GPU
    
  • pip install matplotlib
    
  • python Preprocessing.py 
    

venv

  • python -m venv venv
    
  • ./venv/Scripts/activate
    
  • ./venv/Scripts/pip.exe install matplotlib
    
  • ./venv/Scripts/python.exe Preprocessing.py 
    

在这里插入图片描述

Windows 11

  • pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
    

Windows 10

  • pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
    
  • pip install pillow
    

macOS

  • pip install torch torchvision torchaudio
    
  • pip install numpy
    
  • pip install pillow
    

Linux

  • I don't know.
    

4. 数据集预处理

在正式训练之前需要执行 Preprocessing.py 文件,对图片进行统计和处理。

import os
from utils.tool import RandomSeed, Process

if __name__ == '__main__':
    os.system('')	# 使得终端 ASCII 颜色可以正常转义.
    print('\033[H\033[J', end = '')		# 清屏.
    RandomSeed(42)    # 播下随机种子, 使得每次执行程序结果相同.
    Process(
        dir = './datasets/Animals',  # 需要处理的图片集目录.
        ratio = [7, 2, 1], # 训练集、测试集、验证集比例, 不限格式.
        show = True,    # 展示预处理结果.
        resolution = True   # 统计图片平均分辨率.
    )
>>> python ./Preprocessing.py

在这里插入图片描述

5. 训练

"./cache/log" 文件夹的 Animals.pt、Fruits.pt 和 Bloodcells.pt 是我已经训练 200 轮后,得到了非常好的权重文件,你可以直接拿这两个权重文件进行测试、预测,跳过训练环节。毕竟,训练是很耗时间的。

如果你想重头开始进行水果、动物、血细胞图片集训练,或者对自己的数据集进行训练,你可以在 Preprocessing.py 完成后,执行 Train.py 文件。

import torch
import platform
from utils.net import Transformer
from utils.interface import TrainActivate

if __name__ == '__main__':
    DATASET_NAME = 'Fruits' # 图片集的文件夹名称, 即训练的类型.
    device = torch.device(
        ('mps' if platform.system() == 'Darwin' else 'cuda')
        if torch.cuda.is_available() else 'cpu'
    )   # 程序计算使用的处理器.
    net = Transformer(info_path = f'./cache/{DATASET_NAME}-info.json').to(device)
    optimizer = torch.optim.Adam(params = net.parameters(), lr = 3e-5)
    parameters: dict = {
        'train_test_valid_set_map': f'./cache/{DATASET_NAME}-map.json',
        'epoch': 200,   # 迭代次数.
        'batch_size': 256,  # 处理器一次性计算的图片个数.
        'model': net,
        'device': device,
        'optimizer': optimizer,
        'scheduler': torch.optim.lr_scheduler.StepLR(
            optimizer = optimizer,
            step_size = 1,
            gamma = 0.7
        ),
        'criterion': torch.nn.CrossEntropyLoss()
    }
    TrainActivate(parameters)
>>> python ./Train.py

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

上面这五张连续图片是我对 Animals、Fruits 和 Bloodcells 三类图片集各训练 200 次的过程,其中 Animals 耗时 03:31:17 较长, Fruits 训练耗时 01:54:29 较短,Bloodcells 训练耗时 02:25:42 适中。

程序代码自动将训练结果全部保存在 "./cache" 目录下,其中子文件夹 log 是存放训练权重,你可以看到 log 里的动物和水果权重,这是我训练出的最优权重。

6. 测试

Illustrate.py 文件是用于绘制训练过程的插图,你可以根据它的结果找一个笼统的区间范围,再从区间内找一个最优的训练权重文件。当然,你可以直接使用 "./cache/log" 目录中我训练好的两个最优权重进行测试。

>>> python ./Illustrate.py

在这里插入图片描述

import torch
import platform
from utils.interface import TestActivate

if __name__ == '__main__':
    params = {
        'batch_size': 128,  # 处理器一次计算测试的图片个数.
        # 映射字典路径.
        'train_test_valid_set_map': './cache/Fruits-map.json',
        # 权重文件路径.
        'weight': './cache/log/Fruits-train_loss(0.12103)valid_loss(0.04023).pt',
        'device': torch.device(
            ('mps' if platform.system() == 'Darwin' else 'cuda')
            if torch.cuda.is_available() else 'cpu'
        )
    }
    TestActivate(params)
>>> python ./Test.py

最终测试结果会返回并保存一个正确率指数,水果测试集达到 98% 的正确率,相当高。

在这里插入图片描述

7. 预测

水果 Fruits 测试集正确率高达 98%,因此表明训练的权重相当好,那么我们可将这个权重用于未知图片集的预测。

import torch
import platform
from utils.interface import PredictActivate

if __name__ == '__main__':
    params = {
        'batch_size': 1536, # 处理器一次性预测计算的图片个数.
        # 数据集详情文件路径.
        'info_path': './cache/Fruits-info.json',
        # 待预测的图片集目录.
        'predict_images_dir': './datasets/Unknown-Fruits',
        # 使用训练好的权重文件路径.
        'weight': './cache/log/Fruits-train_loss(0.12103)valid_loss(0.04023).pt',
        'device': torch.device(
            ('mps' if platform.system() == 'Darwin' else 'cuda')
            if torch.cuda.is_available() else 'cpu'
        )
    }
    PredictActivate(params)
>>> python ./Predict.py

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

8. 开源同步

GPLv3

https://orzzz.net

GitHub Repository


http://www.kler.cn/news/365797.html

相关文章:

  • 一站式学习 Shell 脚本语法与编程技巧,踏出自动化的第一步
  • [Redis] Redis数据持久化
  • arcgis中dem转模型导入3dmax
  • grafana 和 prometheus
  • 时间序列预测(十)——长短期记忆网络(LSTM)
  • H7-TOOL的LUA小程序教程第15期:电压,电流,NTC热敏电阻以及4-20mA输入(2024-10-21,已经发布)
  • 【web安全】缓慢的HTTP拒绝服务攻击详解
  • 从零开始:Python与Jupyter Notebook中的数据可视化之旅
  • 从PDF文件中提取数据
  • 理解深度学习模型——高级音频特征表示的分层理解
  • 【Linux系统】如何证明进程的独立性
  • 追寻数组的轨迹,解开算法的情愫
  • 【CCL】浅析 CFX Command Language
  • Java应用程序的测试覆盖率之设计与实现(一)-- 总体设计
  • 51单片机的学习之路1
  • ArcGIS001:ArcGIS10.2安装教程
  • 【实战案例】Django框架连接并操作数据库MySQL相关API
  • Obsidian·Zotero·无法联动问题
  • FFMPEG+Qt 实时显示本机USB摄像头1080p画面以及同步录制mp4视频
  • Svelte 5 正式发布:新一代前端框架!
  • 深入理解 Python 中的 threading.Lock:并发编程的基石
  • 一站式学习 Shell 脚本语法与编程技巧,踏出自动化的第一步
  • 桂城真题长方形
  • 计算机网络原理总结B-数据链路层
  • [CSP-J 2023] 一元二次方程(模拟)
  • 系统架构图设计(行业领域架构)