当前位置: 首页 > article >正文

利用 Resnet50 重新训练,完成宠物数据集的识别,附源代码。。

如果你对深度学习有所了解,知道神经网络可以识别图片,但还没自己动手训练过模型,这篇文章会非常适合你。

这篇文章将使用 PyTorch 和 ResNet50,基于 Oxford-IIIT Pet 数据集(37 类宠物)完成一个完整的训练过程。

这个方法也可以应用到你自己的数据集上,比如识别不同种类的花或物体。

接下来,带你一步步完成这个任务。

Attention:全网最全的 AI 小白到 AI 大神的天梯成长学习路线,几十万原创专栏和硬核视频,点击这里查看:AI小白到AI大神的天梯之路

什么是 ResNet50,为什么选择它?

ResNet50 是一个深度卷积神经网络,包含 50 层,设计用来处理图像分类任务。

它在 ImageNet 数据集上表现优异,能识别 1000 种物体。

我们今天的目标是重新训练它,让它学会识别新的类别——37 种宠物。

选择 ResNet50 的理由很简单——

  • 成熟的结构,它已经被广泛验证,适合大多数图像分类任务。
  • 开箱即用:PyTorch 提供了现成的实现,省去自己设计的麻烦。
  • 高效性:即使从零开始训练,也能得到不错的结果。

下面,我们将训练过程拆成几个关键步骤,逐步讲解。

训练 ResNet50 的四大步骤

步骤 1:准备数据

模型训练的第一步是准备数据。

Oxford-IIIT Pet 数据集包含大量宠物照片,我们需要调整它们的格式,确保模型能正确处理。

代码是这样实现的:

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 将图像调整为 224x224 像素
    transforms.ToTensor(),          # 将图像转换为 Tensor 格式
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])  # 标准化像素值])   
  • ResizeResNet50 的输入要求是 224x224,所有图像需要统一到这个尺寸。
  • ToTensor将图片从普通格式转为模型能处理的数字格式(范围 0 到 1)。
  • Normalize用 ImageNet 的均值和标准差标准化数据,帮助模型更快收敛。

接着,用 DataLoader 将数据分成小批次:

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)

这里 batch_size=32 表示每次处理 32 张图片,shuffle=True 打乱顺序,避免模型记住数据的排列。

步骤 2:搭建模型——调整 ResNet50 的结构

ResNet50 是一个现成的模型,但我们需要根据任务调整它。

原始的 ResNet50 输出 1000 类,而我们的数据集只有 37 类,因此需要修改最后一层。

代码实现如下:

model = torchvision.models.resnet50(weights=None)  # 初始化 ResNet50,不使用预训练权重
model.fc = nn.Linear(model.fc.in_features, 37)     # 将全连接层改为 37 类输出
model = model.to(device)      # 转移到 GPU 或 CPU
  • weights=None表示将从零开始训练模型。
  • model.fc这一行代码修改了模型最后一层(全连接层),将输出特征数改为 37 个,对应 37 类宠物。如果你有自己的数据集,且分类数量与原始模型不一致,也需要进行类似的修改。
  • to(device)根据设备(GPU 或 CPU)运行模型,GPU 会显著加速训练。

步骤 3:定义学习方式

模型需要知道如何学习以及学习步长是什么样的,这样才能优化模型参数的调整过程。

这个过程主要涉及损失函数和优化器。

损失函数衡量的是模型预测值与真实答案之间的差距,优化器则负责调整模型的参数。

用代码中是这样定义的:

criterion = nn.CrossEntropyLoss()          # 交叉熵损失,用于分类任务
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam 优化器,学习率 0.001
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # 学习率调度器
  • 损失函数采用的是交叉熵损失函数,该函数是多分类任务的标准选择。
  • 优化器Adam 是一种高效的优化算法,lr=0.001 是初始学习率。
  • 调度器每 5 个 epoch,学习率乘以 0.1,逐步降低以稳定训练。

步骤 4:训练与测试——让模型学习和验证

训练其实就是让模型反复调整自己参数的过程,验证则是检查训练的效果。

训练和验证的逻辑分别在两个函数中实现。

训练函数:

def train(epoch):
    model.train()  # 进入训练模式
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()  # 清零梯度
        outputs = model(inputs)  # 前向传播
        loss = criterion(outputs, targets)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数
  • train()激活模型的训练模式(启用 dropout/BN 层的全局统计功能)。
  • 流程模型预测 -> 计算损失 -> 调整参数。

测试函数:

def test(epoch):
    model.eval()  # 进入测试模式
    with torch.no_grad():  # 关闭梯度计算
        for inputs, targets in test_loader:
            outputs = model(inputs)
            # 计算准确率...
  • eval()切换到测试模式,关闭训练时的随机性(Dropout, BN 不再进行全局统计)。
  • no_grad()节省内存,提高测试效率。

主循环运行 20 个 epoch,每次训练后测试,并保存最佳模型:

for epoch in range(1, 21):
    train(epoch)
    test_acc = test(epoch)
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), "best_pet_model.pth")

训练效果

运行完整的代码后,你会看到类似这样的结果:

Epoch 1 | Train Acc: 50.23%Epoch 1 | Test Acc: 52.10%...Best Test Accuracy: 85.67%

这表示模型在测试集上的最高准确率达到 85.67%。

如果效果不理想,可以尝试下面的改进方法。

改进建议

使用预训练权重

weights=None 改为 weights='DEFAULT',利用 ImageNet 的经验加速训练。

数据增强

transform 中加入 transforms.RandomHorizontalFlip(),增加数据多样性。

调整参数

尝试不同的学习率(如 0.0001)或 batch_size(如 64),找到最佳组合。

通过以上的四个步骤——准备数据、搭建模型、设定规则、训练测试,你就可以用 ResNet50 训练自己的数据集了。

这个过程并不复杂,只要理解每个部分的逻辑,就能灵活应用到其他任务上。

如果你有自己的数据集,不妨试一试。

宠物训练的完整代码见这里:https://github.com/dongdongcan/ai_model_samples/tree/main/resnet50_train_oxford_iiit_pet

备注,本文的完整代码最好在 GPU 环境下运行。


http://www.kler.cn/a/593314.html

相关文章:

  • 主流区块链平台对 EVM 的依赖情况分类说明
  • C#语言的响应式设计
  • OPC服务器开发之WtOPCSvr(3)
  • Linux 实时同步服务实现(Rsync 结合 Inotify)
  • 在鸿蒙NEXT中开发一个2048小游戏
  • umi自带的tailwindcss修改为手动安装
  • PyQt6嵌入HTML5内容教程
  • STM32-汇编
  • Kafka 的消息机制以及消息丢失等问题
  • LeetCode hot 100—每日温度
  • Flutter:页面滚动,导航栏背景颜色过渡动画
  • 鸿蒙NEXT项目实战-百得知识库01
  • 无线数据网关 自动化测控的LoRa-4G混合网络 串口升级、信号扩展 高效物联传输网络
  • 实验3:Vue.js组件实验
  • 蓝桥杯算法分享:征服三座算法高峰
  • vue3_弹窗数字表单组件
  • docker 部署elk 设置账号密码
  • MySQL数据高效同步到Elasticsearch的四大方案
  • 利用大语言模型生成的合成数据训练YOLOv12:提升商业果园苹果检测的精度与效率
  • 【QT】】qcustomplot的初步使用二