【深度学习】PyTorch :调用残差网络(ResNet)
ResNet (Residual Network) 是由 Microsoft Research 的 Kaiming He 等人在 2015 年提出的一种深度学习模型结构。它解决了随着网络深度增加而导致的梯度消失和退化问题。传统的深层网络可能由于信息难以有效传递,导致模型性能下降,而 ResNet 通过引入残差连接(skip connections),使信息可以跨层直接传递,从而缓解了这一问题。
基本原理
ResNet 的核心思想是学习残差函数而不是直接学习期望的映射函数。具体来说,假设希望学习的目标映射为 H(x) ,ResNet 让每个模块学习一个残差函数 F(x)=H(x)−x ,这样原始映射变成 H(x)=F(x)+x 。这种设计使得梯度更容易反向传播,有助于训练更深层的网络。
常见的 ResNet 结构包括 ResNet-18、ResNet-34、ResNet-50、ResNet-101 等,它们通过不同的层数适应从简单到复杂的任务需求。
导入必要的包
确保安装 PyTorch 和 torchvision:
pip install torch torchvision
在代码中导入相关模块:
import torch
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
实例化预训练 ResNet 模型
通过 torchvision.models
获取预训练的 ResNet 模型:
# 实例化 ResNet-50 模型,并使用预训练权重
model = models.resnet50(pretrained=True)
# 切换模型到计算设备(GPU 或 CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
修改输出层适应新任务
如果要将 ResNet 应用于自定义分类任务,需要修改其最后的全连接层:
# 假设新任务有 10 个类别
num_classes = 10
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)
数据预处理与加载
使用 torchvision.transforms
对图像数据进行预处理:
# 定义数据变换
transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
模型训练
定义损失函数和优化器,并进行模型训练:
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型
num_epochs = 5
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
# 清零梯度
optimizer.zero_grad()
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')
print('训练完成!')
测试模型性能
在测试集上评估模型的分类准确率:
# 加载测试数据
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
# 测试模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f'测试准确率: {accuracy:.2f}%')
总结
通过上述步骤,您可以在 PyTorch 中快速使用预训练的 ResNet 模型,并根据不同任务需求进行定制和优化。ResNet 强大的残差学习能力使其成为许多计算机视觉任务的首选模型。