当前位置: 首页 > article >正文

深度学习-医学影像诊断

以下以使用深度学习进行医学影像(如 X 光片)的肺炎诊断为例,为你展示基于 PyTorch 框架的代码实现。我们将构建一个简单的卷积神经网络(CNN)模型,使用公开的肺炎 X 光影像数据集进行训练和评估。

1. 安装必要的库

pip install torch torchvision numpy matplotlib pandas

2. 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载数据集
train_dataset = datasets.ImageFolder(root='path/to/train_data', transform=transform)
test_dataset = datasets.ImageFolder(root='path/to/test_data', transform=transform)

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 定义简单的 CNN 模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(32 * 56 * 56, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = x.view(-1, 32 * 56 * 56)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x

# 初始化模型、损失函数和优化器
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10
train_losses = []
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}')

# 绘制训练损失曲线
plt.plot(train_losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.show()

# 评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')

3. 代码解释

  • 数据预处理

    • 使用 transforms.Compose 定义了一系列的数据预处理操作,包括调整图像大小、转换为张量和归一化。
    • transforms.Resize((224, 224)) 将图像调整为 224x224 大小。
    • transforms.ToTensor() 将图像转换为张量。
    • transforms.Normalize 对图像进行归一化处理。
  • 数据集加载

    • 使用 datasets.ImageFolder 加载训练集和测试集,需要将 path/to/train_datapath/to/test_data 替换为实际的数据集路径。
    • DataLoader 用于创建数据加载器,方便批量加载数据。
  • 模型定义

    • SimpleCNN 类定义了一个简单的卷积神经网络模型,包含两个卷积层、两个池化层和两个全连接层。
  • 训练过程

    • 使用 nn.CrossEntropyLoss 作为损失函数,optim.Adam 作为优化器。
    • 在每个 epoch 中,遍历训练数据,计算损失并进行反向传播和参数更新。
  • 模型评估

    • 将模型设置为评估模式(model.eval()),在测试集上进行预测,并计算准确率。

4. 注意事项

  • 数据集:你需要准备合适的医学影像数据集,并将其按照训练集和测试集进行划分,每个类别放在不同的文件夹中。
  • 模型复杂度:这里的 SimpleCNN 是一个简单的模型,在实际应用中,可能需要使用更复杂的预训练模型(如 ResNet、DenseNet 等)来提高诊断准确率。
  • 计算资源:训练深度学习模型需要一定的计算资源,建议在 GPU 上运行以提高训练速度。可以使用 torch.cuda.is_available() 检查是否有可用的 GPU,并将模型和数据移动到 GPU 上进行训练。例如:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
images, labels = images.to(device), labels.to(device)

如果你有其他具体需求,如使用不同的模型架构、处理不同类型的医学影像等,可以进一步调整代码。


http://www.kler.cn/a/542750.html

相关文章:

  • ubuntu如何设置停止程序自动更新
  • 【Obsidian】当笔记接入AI,Copilot插件推荐
  • bazel 小白理解
  • 国产编辑器EverEdit - 迷你查找
  • Vue3+codemirror6实现公式(规则)编辑器
  • Log4j定制JSON格式日志输出
  • Go 1.4操作符指针理解
  • 《从入门到精通:蓝桥杯编程大赛知识点全攻略》(十二)-航班时间、日志统计、献给阿尔吉侬的花束
  • NLP面试-Transformer
  • 【后端发展路径】基础技术栈、工程能力进阶、高阶方向、职业发展路径
  • vue3自定义loading加载动画指令
  • Java集合List详解(带脑图)
  • 基于微信小程序的刷题系统的设计与实现springboot+论文源码调试讲解
  • 开发中用到的设计模式
  • Excel 笔记
  • 【哇! C++】第一个C++语言程序
  • docker compose部署dragonfly
  • 《pytorch》——优化器的解析和使用
  • 【含文档+PPT+源码】基于微信小程序的在线考试与选课教学辅助系统
  • Goland的context原理(存在问题,之前根本没有了解,需要更加深入了解)
  • 前端首屏时间优化方案
  • Python实现机器学习舆情分析项目的经验分享
  • Centos10 Stream 基础配置
  • 数据结构 双链表的模拟实现
  • 【前端】【面试】ref与reactive的区别
  • C# OpenCV机器视觉:模仿Halcon各向异性扩散滤波