当前位置: 首页 > article >正文

【深度学习】PyTorch :调用残差网络(ResNet)

ResNet (Residual Network) 是由 Microsoft Research 的 Kaiming He 等人在 2015 年提出的一种深度学习模型结构。它解决了随着网络深度增加而导致的梯度消失和退化问题。传统的深层网络可能由于信息难以有效传递,导致模型性能下降,而 ResNet 通过引入残差连接(skip connections),使信息可以跨层直接传递,从而缓解了这一问题。

基本原理

ResNet 的核心思想是学习残差函数而不是直接学习期望的映射函数。具体来说,假设希望学习的目标映射为 H(x) ,ResNet 让每个模块学习一个残差函数 F(x)=H(x)−x ,这样原始映射变成 H(x)=F(x)+x 。这种设计使得梯度更容易反向传播,有助于训练更深层的网络。

常见的 ResNet 结构包括 ResNet-18、ResNet-34、ResNet-50、ResNet-101 等,它们通过不同的层数适应从简单到复杂的任务需求。

导入必要的包

确保安装 PyTorch 和 torchvision:

pip install torch torchvision

在代码中导入相关模块:

import torch
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

实例化预训练 ResNet 模型

通过 torchvision.models 获取预训练的 ResNet 模型:

# 实例化 ResNet-50 模型,并使用预训练权重
model = models.resnet50(pretrained=True)

# 切换模型到计算设备(GPU 或 CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

修改输出层适应新任务

如果要将 ResNet 应用于自定义分类任务,需要修改其最后的全连接层:

# 假设新任务有 10 个类别
num_classes = 10
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

数据预处理与加载

使用 torchvision.transforms 对图像数据进行预处理:

# 定义数据变换
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

模型训练

定义损失函数和优化器,并进行模型训练:

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # 清零梯度
        optimizer.zero_grad()

        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

print('训练完成!')

测试模型性能

在测试集上评估模型的分类准确率:

# 加载测试数据
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

# 测试模型
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'测试准确率: {accuracy:.2f}%')

总结

通过上述步骤,您可以在 PyTorch 中快速使用预训练的 ResNet 模型,并根据不同任务需求进行定制和优化。ResNet 强大的残差学习能力使其成为许多计算机视觉任务的首选模型。


http://www.kler.cn/a/506479.html

相关文章:

  • 【网络编程】基础知识
  • 大疆最新款无人机发布,可照亮百米之外目标
  • 【网络 MAC 学习专栏 -- 如何理解 PHY 的 Link Up】
  • 优化 Vue项目中 app.js 文件过大,初始化加载过慢、带宽占用过大等问题
  • 深度学习电影推荐-CNN算法
  • Docker--Docker Compose(容器编排)
  • nginx反向代理http 和 https(案例)
  • 域名劫持是怎么回事?怎么解决?
  • docker安装和测试redis步骤
  • 8.BMS SOC的算法总结
  • 【20250115】Nature子刊:柔性生物传感与深度学习结合的上肢运动增强外骨骼机器人...
  • 【Rust自学】12.4. 重构 Pt.2:错误处理
  • 酷柚易汛ERP 2025-01-16系统升级日志
  • 【C++ 类和对象 进阶篇】—— 逻辑森林的灵动精灵,舞动类与对象的奇幻圆舞曲
  • elrond32
  • QT跨平台应用程序开发框架(3)—— 信号和槽
  • 【深度学习】关键技术-优化算法(Optimization Algorithms)详解与代码示例
  • shell练习(3)
  • SQL-leetcode—626. 换座位
  • opencv_图像处理_去噪声_采用中值滤波
  • 设计模式相关面试
  • php审计1-extract函数变量覆盖
  • 数据仓库的复用性:模型层面通用指标体系、参数化模型、版本化管理
  • Rust中的Rc. Cell, RefCell
  • redis-6.26主从配置
  • 【AI落地】如何创建字节的coze扣子工作流 ——以“批量获取抖音视频文案”为例