当前位置: 首页 > article >正文

丹摩 | 基于PyTorch的CIFAR-10图像分类实现


从创建实例开始的新项目流程

第一步:创建实例

  1. 登录 DAMODEL 平台。
  2. 创建一个 GPU 实例:
    在这里插入图片描述
    • GPU 配置:选择 NVIDIA H800 或其他可用高性能 GPU。
      在这里插入图片描述

    • 系统配置:推荐使用 Ubuntu 20.04,内存 16GB,硬盘 50GB。

    • 启动实例后,获取实例的 IP 地址。

    • 选择镜像
      在这里插入图片描述


第二步:连接实例

在这里插入图片描述

  1. 登录成功后,你会进入实例的终端界面。
    在这里插入图片描述
    在这里插入图片描述

第三步:更新系统和安装基础工具

  1. 更新系统:

    sudo apt update && sudo apt upgrade -y
    
  2. 安装 Python 和基础工具:

    sudo apt install python3 python3-pip git -y
    
  3. (可选)安装文本编辑器:

    sudo apt install vim nano -y
    

第四步:创建项目目录并配置环境

  1. 创建项目目录:

    mkdir ~/workspace/cifar10_project
    cd ~/workspace/cifar10_project
    
  2. 创建并激活虚拟环境:

    python3 -m venv venv
    source venv/bin/activate
    

    在这里插入图片描述
    前面出现venu则表示已经激活虚拟环境了

  3. 安装必要的 Python 包:

    pip install torch torchvision matplotlib
    

在这里插入图片描述

第五步:下载数据并初始化项目代码

  1. 创建 Python 脚本:

    vim train_cifar10.py
    
  2. 在文件中输入以下代码,加载 CIFAR-10 数据集并定义简单模型:

    import torch
    import torchvision
    import torchvision.transforms as transforms
    import torch.nn as nn
    import torch.optim as optim
    
    # 数据预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    # 加载 CIFAR-10 数据集
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
    
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
    
    # 定义简单卷积神经网络
    class SimpleCNN(nn.Module):
        def __init__(self):
            super(SimpleCNN, self).__init__()
            self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
            self.fc1 = nn.Linear(32 * 16 * 16, 10)
    
        def forward(self, x):
            x = self.pool(torch.relu(self.conv1(x)))
            x = x.view(-1, 32 * 16 * 16)
            x = self.fc1(x)
            return x
    
    # 初始化模型、损失函数和优化器
    net = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    
    # 模型训练
    for epoch in range(5):  # 训练 5 个周期
        running_loss = 0.0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss / len(trainloader)}")
    
    print("Finished Training")
    
  3. 保存并退出(按下 Esc,然后输入 :wq)。


第六步:运行训练脚本

运行脚本进行模型训练:

python train_cifar10.py
  • 脚本会下载 CIFAR-10 数据集并训练模型。
  • 训练完成后会输出每个 epoch 的损失值。
    在这里插入图片描述

第七步:保存和测试模型

  1. 保存模型:在脚本末尾添加代码以保存训练好的模型:

    torch.save(net.state_dict(), "cifar10_model.pth")
    print("Model saved as cifar10_model.pth")
    
  2. 重新运行脚本以保存模型:

    python train_cifar10.py
    
  3. 检查是否生成了 cifar10_model.pth 文件:

    ls
    
  4. 测试模型(可选):加载保存的模型并在测试集上评估准确率:

    net.load_state_dict(torch.load("cifar10_model.pth"))
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    print(f"Accuracy on test dataset: {100 * correct / total}%")
    

第八步:清理和扩展

  1. 扩展功能

    • 使用更复杂的模型(如 ResNet)。
    • 尝试使用 Adam 优化器提高性能。
    • 可视化训练过程或模型预测结果。
  2. 清理资源

    • 如果完成训练并不再需要 GPU 计算,记得停止或删除实例以节省费用。

\


http://www.kler.cn/a/412646.html

相关文章:

  • 《热带气象学报》
  • Flink CDC 使用实践以及遇到的问题
  • hping3工具介绍及使用方法
  • java写一个石头剪刀布小游戏
  • 概率论中交并集的公式
  • 代码随想录算法训练营day46|动态规划09
  • 第三方数据库连接免费使用和安装
  • 白光干涉仪:表面粗糙度形貌台阶高测量解决方案
  • Flutter 共性元素动画
  • 工业网络安全 智能电网,SCADA和其他工业控制系统等关键基础设施的网络安全(总结)
  • 无法通过外网连接访问mysql问题排查
  • 如何通过终端连接无线网
  • echarts使用示例
  • laravel官方升级引起的报错问题解决
  • 原子类、AtomicLong、AtomicReference、AtomicIntegerFieldUpdater、LongAdder
  • [python]poetry安装和使用
  • Vue前端面试进阶(五)
  • day29|leetcode 134. 加油站 , 135. 分发糖果 ,860.柠檬水找零 , 406.根据身高重建队列
  • 模型压缩理论简介及剪枝与稀疏化在 征程 5 上实践
  • 检测到“runtimelibrary”的不匹配项: 值“mtd_staticdebug”不匹配值“mdd_dynamic”
  • 基于MFC实现的俄罗斯方块游戏
  • cgroup简介
  • 深入理解 TypeScript:联合类型与交叉类型的应用
  • 如何编写出色的技术文档
  • shell(4)脚本与用户交互以及if条件判断
  • 第三十二章 UDP 客户端 服务器通信