当前位置: 首页 > article >正文

深度学习实战基础案例——卷积神经网络(CNN)基于DenseNet的眼疾检测|第4例

文章目录

  • 前言
  • 一、数据准备
  • 二、项目实战
    • 2.1 设置GPU
    • 2.2 数据加载
    • 2.3 数据预处理
    • 2.4 数据划分
    • 2.5 搭建网络模型
    • 2.6 构建densenet121
    • 2.7 训练模型
    • 2.8 结果可视化
  • 三、UI设计
  • 四、结果展示
  • 总结

前言

在当今社会,眼科疾病尤其是白内障对人们的视力健康构成了严重威胁。白内障是全球范围内导致失明的主要原因之一,早期准确的诊断对于疾病的治疗和患者的预后至关重要。传统的白内障检测方法主要依赖于眼科医生的专业判断,这不仅需要大量的人力和时间,而且诊断结果可能会受到医生经验和主观因素的影响。
随着深度学习技术的飞速发展,其在医疗图像分析领域展现出了巨大的潜力。卷积神经网络(CNN)作为深度学习中的重要模型,已经在多种医疗图像识别任务中取得了显著的成果,如肿瘤检测、疾病分类等。利用 CNN 对眼科图像进行分析,可以辅助医生更快速、准确地进行疾病诊断。
本文将详细介绍如何使用基于 DenseNet 的卷积神经网络进行白内障疾病检测。通过这个实战案例,不仅可以帮助读者了解 DenseNet 的原理和应用,还能掌握利用深度学习进行医疗图像分析的基本流程和方法,为进一步开展相关研究和实践提供参考。

一、数据准备

本案例使用的数据集是retina_dataset|眼科疾病数据集。
数据集下载地址:点击这里
Retina Dataset的构建基于眼底图像的分类需求,涵盖了四种主要的眼科疾病类别:正常、白内障、青光眼和视网膜疾病。数据集通过收集和整理不同患者的视网膜图像,确保每类疾病均有代表性样本。图像数据经过标准化处理,以保证在不同设备和条件下获取的图像具有一致性,从而为后续的分类和分析提供了坚实的基础。

在这里插入图片描述

二、项目实战


我的环境:

  • 基础环境:Python3.9
  • 编译器:PyCharm 2024
  • 深度学习框架:Pytorch2.0

2.1 设置GPU

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")             #忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

2.2 数据加载

import os,PIL,random,pathlib

data_dir = '数据路径'
data_dir = pathlib.Path(data_dir)

data_paths  = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]

2.3 数据预处理

train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    # transforms.RandomHorizontalFlip(), # 随机水平翻转
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

test_transform = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])
total_data = datasets.ImageFolder(data_dir,transform=train_transforms)

2.4 数据划分

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
batch_size = 32

train_dl = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

2.5 搭建网络模型

import torch.nn as nn
import torch
from torch import mean, max

class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate=0):
        super(_DenseLayer, self).__init__()
        self.drop_rate = drop_rate
        self.dense_layer = nn.Sequential(
            nn.BatchNorm2d(num_input_features),
            nn.ReLU(),
            nn.Conv2d(in_channels=num_input_features, out_channels=bn_size * growth_rate, kernel_size=1, stride=1,
                      padding=0),
            Inceptionnext(bn_size * growth_rate, bn_size * growth_rate, kernel_size=3),
            CBAMBlock("FC", 5, channels=bn_size * growth_rate, ratio=9),
            nn.Conv2d(in_channels=bn_size * growth_rate, out_channels=growth_rate, kernel_size=1, stride=1, padding=0)
        )
        self.dropout = nn.Dropout(p=self.drop_rate)

    def forward(self, x):
        y = self.dense_layer(x)
        if self.drop_rate > 0:
            y = self.dropout(y)

        return torch.concat([x, y], dim=1)


class _DenseBlock(nn.Module):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate=0):
        super(_DenseBlock, self).__init__()
        layers = []
        for i in range(num_layers):
            layers.append(_DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class _TransitionLayer(nn.Module):
    def __init__(self, num_input_features, num_output_features):
        super(_TransitionLayer, self).__init__()
        self.transition_layer = nn.Sequential(
            nn.BatchNorm2d(num_input_features),
            nn.ReLU(),
            nn.Conv2d(in_channels=num_input_features, out_channels=num_output_features, kernel_size=1, stride=1,
                      padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x):
        return self.transition_layer(x)


class DenseNet(nn.Module):
    def __init__(self, num_init_features=64, growth_rate=32, blocks=(6, 12, 24, 16), bn_size=4, drop_rate=0,
                 num_classes=1000):
        super(DenseNet, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=num_init_features, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        num_features = num_init_features
        self.layer1 = _DenseBlock(num_layers=blocks[0], num_input_features=num_features, growth_rate=growth_rate,
                                  bn_size=bn_size, drop_rate=drop_rate)
        num_features = num_features + blocks[0] * growth_rate
        self.transtion1 = _TransitionLayer(num_input_features=num_features, num_output_features=num_features // 2)

        num_features = num_features // 2
        self.layer2 = _DenseBlock(num_layers=blocks[1], num_input_features=num_features, growth_rate=growth_rate,
                                  bn_size=bn_size, drop_rate=drop_rate)
        num_features = num_features + blocks[1] * growth_rate
        self.transtion2 = _TransitionLayer(num_input_features=num_features, num_output_features=num_features // 2)

        num_features = num_features // 2
        self.layer3 = _DenseBlock(num_layers=blocks[2], num_input_features=num_features, growth_rate=growth_rate,
                                  bn_size=bn_size, drop_rate=drop_rate)
        num_features = num_features + blocks[2] * growth_rate
        self.transtion3 = _TransitionLayer(num_input_features=num_features, num_output_features=num_features // 2)

        num_features = num_features // 2
        self.layer4 = _DenseBlock(num_layers=blocks[3], num_input_features=num_features, growth_rate=growth_rate,
                                  bn_size=bn_size, drop_rate=drop_rate)
        num_features = num_features + blocks[3] * growth_rate

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        x = self.features(x)

        x = self.layer1(x)
        x = self.transtion1(x)
        x = self.layer2(x)
        x = self.transtion2(x)
        x = self.layer3(x)
        x = self.transtion3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        y = torch.flatten(x, start_dim=1)
        x = self.fc(y)

        return x

2.6 构建densenet121

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

densenet121 = DenseNet(blocks=(6,12,24,16),
                       num_classes=len(classeNames))  

model = densenet121.to(device)

2.7 训练模型

# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小
    num_batches = len(dataloader)   # 批次数目, (size/batch_size,向上取整)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        
        # 计算预测误差
        pred = model(X)          # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()        # 反向传播
        optimizer.step()       # 每一步自动更新
        
        # 记录acc与loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches

    return train_acc, train_loss

def test (dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # 测试集的大小
    num_batches = len(dataloader)          # 批次数目, (size/batch_size,向上取整)
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)
            
            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size
    test_loss /= num_batches

    return test_acc, test_loss

import copy

optimizer  = torch.optim.Adam(model.parameters(), lr= 1e-4)
loss_fn    = nn.CrossEntropyLoss() # 创建损失函数

epochs     = 20

train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []

best_acc = 0    # 设置一个最佳准确率,作为最佳模型的判别指标

for epoch in range(epochs):
    
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    # 保存最佳模型到 best_model
    if epoch_test_acc > best_acc:
        best_acc   = epoch_test_acc
        best_model = copy.deepcopy(model)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    # 获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, 
                          epoch_test_acc*100, epoch_test_loss, lr))
    
# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(best_model.state_dict(), PATH)

print('Done')

2.8 结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Test Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Test Loss')
plt.show()

在这里插入图片描述

三、UI设计

这里使用QT Designder设计了一个简易的UI界面,可以很方便的进行使用。
在这里插入图片描述
UI.py文件如下


import test
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtGui import QIcon
from PyQt5.QtWidgets import QFileDialog, QMessageBox

imgName=None


class Ui_Form(object):
    def setupUi(self, Form):
        Form.setObjectName("Form")
        Form.resize(649, 559)
        self.label = QtWidgets.QLabel(Form)
        self.label.setGeometry(QtCore.QRect(120, 20, 331, 41))
        self.label.setStyleSheet("\n"
"font: 22pt \"华文彩云\";")
        self.label.setObjectName("label")
        self.label_2 = QtWidgets.QLabel(Form)
        self.label_2.setGeometry(QtCore.QRect(120, 130, 311, 251))
        self.label_2.setStyleSheet("border-image: url(:/新前缀/img.png);")
        self.label_2.setText("")
        self.label_2.setObjectName("label_2")
        self.label_4 = QtWidgets.QLabel(Form)
        self.label_4.setGeometry(QtCore.QRect(60, 440, 72, 15))
        self.label_4.setObjectName("label_4")
        self.textEdit = QtWidgets.QTextEdit(Form)
        self.textEdit.setGeometry(QtCore.QRect(140, 440, 211, 91))
        self.textEdit.setObjectName("textEdit")
        self.layoutWidget = QtWidgets.QWidget(Form)
        self.layoutWidget.setGeometry(QtCore.QRect(60, 400, 221, 31))
        self.layoutWidget.setObjectName("layoutWidget")
        self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.layoutWidget)
        self.horizontalLayout_2.setContentsMargins(0, 0, 0, 0)
        self.horizontalLayout_2.setObjectName("horizontalLayout_2")
        self.label_3 = QtWidgets.QLabel(self.layoutWidget)
        self.label_3.setObjectName("label_3")
        self.horizontalLayout_2.addWidget(self.label_3)
        self.lineEdit_2 = QtWidgets.QLineEdit(self.layoutWidget)
        self.lineEdit_2.setObjectName("lineEdit_2")
        self.horizontalLayout_2.addWidget(self.lineEdit_2)
        self.layoutWidget1 = QtWidgets.QWidget(Form)
        self.layoutWidget1.setGeometry(QtCore.QRect(30, 70, 591, 41))
        self.layoutWidget1.setObjectName("layoutWidget1")
        self.horizontalLayout_3 = QtWidgets.QHBoxLayout(self.layoutWidget1)
        self.horizontalLayout_3.setContentsMargins(0, 0, 0, 0)
        self.horizontalLayout_3.setObjectName("horizontalLayout_3")
        self.horizontalLayout = QtWidgets.QHBoxLayout()
        self.horizontalLayout.setObjectName("horizontalLayout")
        self.lineEdit = QtWidgets.QLineEdit(self.layoutWidget1)
        self.lineEdit.setObjectName("lineEdit")
        self.horizontalLayout.addWidget(self.lineEdit)
        self.pushButton = QtWidgets.QPushButton(self.layoutWidget1)
        self.pushButton.setObjectName("pushButton")
        self.horizontalLayout.addWidget(self.pushButton)
        self.horizontalLayout_3.addLayout(self.horizontalLayout)
        self.pushButton_2 = QtWidgets.QPushButton(self.layoutWidget1)
        self.pushButton_2.setObjectName("pushButton_2")
        self.horizontalLayout_3.addWidget(self.pushButton_2)

        self.pushButton.clicked.connect(self.openImage)
        self.pushButton_2.clicked.connect(self.inferImage)

        self.retranslateUi(Form)
        QtCore.QMetaObject.connectSlotsByName(Form)

    def retranslateUi(self, Form):
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "Form"))
        self.label.setText(_translate("Form", "白内障检测系统"))
        self.label_4.setText(_translate("Form", "诊断建议:"))
        self.label_3.setText(_translate("Form", "识别结果:"))
        self.pushButton.setText(_translate("Form", "打开文件"))
        self.pushButton_2.setText(_translate("Form", "开始识别"))

    def openImage(self):  # 选择本地图片上传
        global imgName  # 这里为了方便别的地方引用图片路径,我们把它设置为全局变量
        imgName, imgType = QFileDialog.getOpenFileName(self, "打开图片", "",
                                                       "*.jpg;*.png;;All Files(*)")  # 弹出一个文件选择框,第一个返回值imgName记录选中的文件路径+文件名,第二个返回值imgType记录文件的类型
        jpg = QtGui.QPixmap(imgName).scaled(self.label_2.width(),
                                            self.label_2.height())  # 通过文件路径获取图片文件,并设置图片长宽为label控件的长宽
        self.label_2.setPixmap(jpg)  # 在label控件上显示选择的图片
        self.lineEdit.setText(imgName)  # 显示所选图片的本地路径

    def inferImage(self):
        global imgName
        if imgName is None or imgName == '':
            QMessageBox.information(self, "Error!", "请先选择图片!", QMessageBox.Ok)
            return
        a1, a2 = test.infer(imgName)
        self.lineEdit_2.setText(a1)
        self.textEdit.setText(a2)
import asd_rc

四、结果展示

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

总结

通过本次案例,我们可以对深度学习设计程序的流程有一个简单清楚的认知,以便我们将来构建其它深度学习系统可以更加得心应手。


http://www.kler.cn/a/543355.html

相关文章:

  • 如何在Excel和WPS中进行翻译
  • 在 ARM64 架构系统离线安装 Oracle Java 8 全流程指南
  • mysql读写分离与proxysql的结合
  • 【嵌入式Linux应用开发基础】read函数与write函数
  • 利用HTML和css技术编写学校官网页面
  • 使用epoll与sqlite3进行注册登录
  • 基于Python flask-sqlalchemy的SQLServer数据库管理平台
  • WinForm 防破解、反编译设计文档
  • 2025年3月一区SCI-真菌生长优化算法Fungal growth optimizer-附Matlab免费代码
  • Citus的TPCC、TPCH性能测试
  • 时间敏感和非时间敏感流量的性能保证配置
  • 3dgs 2025 学习笔记
  • 【算法】【双指针】acwing算法基础 2816. 判断子序列
  • 懒人精灵内存插件(手游x86x64内存插件)
  • 芯盾时代数据安全产品体系,筑牢数据安全防线
  • Flowable:现代业务流程管理的解决方案
  • 深度学习新宠:卷积神经网络如何重塑人工智能版图?
  • Django 初学小案例:用户登录
  • ffmpeg -pix_fmts
  • 介绍几款免费的显示器辅助工具!
  • Linux虚拟机克隆
  • 【登录认证】
  • 异步加载和协程+Unity特殊文件夹
  • 不小心删除服务[null]后,git bash出现错误
  • Kimi-1.5与DeepSeek-R1:谁是AI推理的王者?
  • 脉冲当量含义