当前位置: 首页 > article >正文

使用 PyTorch 实现简化版 GoogLeNet 进行 MNIST 图像分类

介绍

        本文将介绍如何使用 PyTorch 实现一个简化版的 GoogLeNet 网络来进行 MNIST 图像分类。GoogLeNet 是 Google 提出的深度卷积神经网络(CNN),其通过 Inception 模块大大提高了计算效率并提升了分类性能。我们将实现一个简化版的 GoogLeNet,用于处理 MNIST 数据集,该数据集由手写数字图片组成,适合用于小规模的图像分类任务。

项目结构

        我们将代码分为两个部分:

  • 训练脚本 train.py:包括数据加载、模型构建、训练过程等。
  • 测试脚本 test.py:用于加载训练好的模型并在测试集上评估性能。

项目依赖

        在开始之前,我们需要安装以下 Python 库:

  • torch:PyTorch 深度学习框架
  • torchvision:提供数据加载和图像变换功能
  • matplotlib:用于可视化

        可以通过以下命令安装所有依赖:

pip install -r requirements.txt

  requirements.txt 文件内容如下:

torch==2.0.1
torchvision==0.15.0
matplotlib==3.6.3

数据预处理与加载

1. 数据加载和预处理

        在训练模型之前,我们需要对 MNIST 数据集进行预处理。以下是数据加载和预处理的代码:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def get_data_loader(batch_size=64, train=True):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # 正规化到 [-1, 1] 范围
    ])

    dataset = datasets.MNIST(root='./data', train=train, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)

        这里,我们使用了 transforms.Compose 来进行数据预处理,包括将图像转换为 Tensor 格式,并进行归一化处理。


训练部分:train.py

2. 模型定义:简化版 GoogLeNet

        为了在 MNIST 数据集上训练,我们构建了一个简化版的 GoogLeNet,包含三个 Inception 模块和一个全连接层。每个 Inception 模块由一个卷积层和一个最大池化层组成。简化的 GoogLeNet 模型如下:

import torch.nn as nn

class SimpleGoogLeNet(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleGoogLeNet, self).__init__()

        # 第一个 Inception 模块
        self.inception1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        # 第二个 Inception 模块
        self.inception2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        # 第三个 Inception 模块
        self.inception3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        # 分类器:全连接层 + Dropout 层
        self.fc = nn.Sequential(
            nn.Linear(128 * 3 * 3, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.inception1(x)
        x = self.inception2(x)
        x = self.inception3(x)
        x = x.view(x.size(0), -1)  # 展平输入
        x = self.fc(x)
        return x

3. 训练函数

        训练过程包括前向传播、反向传播和优化。我们将使用 Adam 优化器和 交叉熵损失 来训练模型:

import torch.optim as optim
from tqdm import tqdm

def train_epoch(model, device, train_loader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    with tqdm(train_loader, desc="Training", unit="batch", ncols=100) as pbar:
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            pbar.set_postfix(loss=running_loss / (total // 64), accuracy=100 * correct / total)

    return running_loss / len(train_loader), 100 * correct / total

4. 训练脚本:train.py

        训练脚本将包括模型的定义、数据加载、训练过程等:

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from model import SimpleGoogLeNet  # 假设模型在 model.py 文件中

def train_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleGoogLeNet().to(device)
    
    train_loader = get_data_loader(batch_size=64, train=True)
    
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    epochs = 10
    for epoch in range(epochs):
        loss, accuracy = train_epoch(model, device, train_loader, criterion, optimizer)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss:.4f}, Accuracy: {accuracy:.2f}%")
        
    torch.save(model.state_dict(), "simplified_googlenet.pth")  # 保存模型

if __name__ == '__main__':
    train_model()


测试部分:test.py

5. 测试函数

        在测试阶段,我们将使用 torch.no_grad() 禁用梯度计算,提高推理速度,并计算模型在测试集上的准确率:

def test_model(model, device, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

6. 测试脚本:test.py

        测试脚本将加载训练好的模型并对测试集进行评估:

import torch
from model import SimpleGoogLeNet  # 假设模型在 model.py 文件中
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def get_test_loader(batch_size=64):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # 正规化到 [-1, 1] 范围
    ])
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    return DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SimpleGoogLeNet().to(device)
    model.load_state_dict(torch.load("simplified_googlenet.pth"))  # 加载训练好的模型
    
    test_loader = get_test_loader(batch_size=64)
    test_model(model, device, test_loader)

if __name__ == '__main__':
    main()

总结

        本文介绍了如何使用 PyTorch 实现简化版的 GoogLeNet,并将代码分为训练(train.py)和测试(test.py)部分。在训练脚本中,我们定义了一个简化版的 GoogLeNet,训练模型并保存训练结果。而在测试脚本中,我们加载训练好的模型并在测试集上进行评估。

        通过这些步骤,我们能够快速地实现一个高效的图像分类模型,并在 MNIST 数据集上进行训练与测试。

完整项目
GitHub - qxd-ljy/GoogLeNet-PyTorch: 使用PyTorch实现GooLeNet进行MINST图像分类使用PyTorch实现GooLeNet进行MINST图像分类. Contribute to qxd-ljy/GoogLeNet-PyTorch development by creating an account on GitHub.icon-default.png?t=O83Ahttps://github.com/qxd-ljy/GoogLeNet-PyTorchGitHub - qxd-ljy/GoogLeNet-PyTorch: 使用PyTorch实现GooLeNet进行MINST图像分类使用PyTorch实现GooLeNet进行MINST图像分类. Contribute to qxd-ljy/GoogLeNet-PyTorch development by creating an account on GitHub.icon-default.png?t=O83Ahttps://github.com/qxd-ljy/GoogLeNet-PyTorch

        希望这篇博客对你有所帮助,欢迎继续探索 PyTorch 和深度学习的更多应用!


http://www.kler.cn/a/400252.html

相关文章:

  • 深入理解Go语言并发编程:从基础到实践
  • 利用Python爬虫获取淘宝店铺详情
  • 第6章详细设计 -6.7 PCB工程需求表单
  • Cyberchef配合Wireshark提取并解析TCP/FTP流量数据包中的文件
  • vs2022搭建opencv开发环境
  • 《Java核心技术 卷I》用户界面中首选项API
  • C# 面向对象
  • MySQL45讲 第二十五讲 高可用性深度剖析:从主备原理到策略选择
  • 淘宝客结合C#使用WebApi和css绘制商品图片
  • 界面控件DevExpress WinForms v24.2新功能预览 - 支持.NET 9
  • 社交电商的优势及其与 AI 智能名片小程序、S2B2C 商城系统的融合发展
  • Java篇String类的常见方法
  • 基于YOLOv8深度学习的智慧交通非机动车驾驶员头盔佩戴检测系统
  • Matlab实现白鲸优化算法优化随机森林算法模型 (BWO-RF)(附源码)
  • 在Keil中使用ST-LINK烧录STM32程序指南
  • 聚焦 AUTO TECH 2025华南展:探索新能源汽车发展新趋势
  • 美赛优秀论文阅读--2023C题
  • Spring Boot汽车资讯:数字化时代的驾驶
  • 前端性能优化深入解析:提升用户体验的几个关键点
  • 工具类-基于 axios 的 http 请求工具 Request
  • ELK8.15.4搭建开启安全认证
  • 基于Vue3与ABP vNext 8.0框架实现耗时业务处理的进度条功能
  • 常见网络厂商设备默认用户名/密码大全
  • 移动端web页面调用原生jsbridge的封装
  • java ssm 高速公路管理系统 公路收费管理 高速收费管理 源码 jsp
  • 【Android】Proxyman 抓 HTTP 数据包