当前位置: 首页 > article >正文

PyTorch 实现手写数字识别

PyTorch 实现手写数字识别

在本教程中,我们将使用 PyTorch 实现经典的手写数字识别任务。我们将使用 MNIST 数据集,这是一个包含手写数字的图像数据集。我们将介绍如何使用 PyTorch 构建、训练和评估一个简单的卷积神经网络(CNN)模型来进行手写数字识别。

1. 项目概述

手写数字识别任务是通过训练模型,让其能够识别手写数字图像并输出正确的数字类别(0-9)。MNIST 数据集包含 28x28 像素的灰度图像,每个图像代表一个手写数字。

我们将使用以下步骤:

  1. 加载 MNIST 数据集
  2. 构建一个卷积神经网络(CNN)
  3. 训练模型
  4. 评估模型性能
  5. 进行测试预测

2. 官方文档链接

  • PyTorch 官方文档
  • MNIST 数据集链接

3. 安装 PyTorch 和依赖库

首先,确保您已经安装了 PyTorch 和相关依赖库。如果没有安装,可以运行以下命令:

pip install torch torchvision matplotlib

4. 加载 MNIST 数据集

我们将使用 torchvision 提供的 MNIST 数据集。它包含 60,000 个训练样本和 10,000 个测试样本。

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# 数据预处理:将图像转换为张量,并进行标准化
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# 下载并加载 MNIST 训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# 加载数据集
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# 查看数据集的大小
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")

# 可视化部分样本
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)
plt.figure(figsize=(10, 3))
for i in range(6):
    plt.subplot(1, 6, i + 1)
    plt.imshow(example_data[i][0], cmap='gray')
    plt.title(f"Label: {example_targets[i]}")
    plt.axis('off')
plt.show()

说明

  • transforms.Compose:我们将图像转换为 PyTorch 张量,并将像素值标准化为 [-1, 1] 的范围。
  • DataLoader:用于将数据集加载为批次,并打乱数据顺序以便训练时使用。

5. 构建卷积神经网络(CNN)

我们将构建一个简单的 CNN 模型,用于手写数字识别。该模型将包含两个卷积层和两个全连接层。

import torch.nn as nn
import torch.nn.functional as F

# 定义 CNN 模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 卷积层1: 输入通道为1(灰度图),输出通道为16,卷积核大小为3x3
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3)
        # 卷积层2: 输入通道为16,输出通道为32,卷积核大小为3x3
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3)
        # 全连接层1: 输入为32*5*5(展平后的特征图),输出为128
        self.fc1 = nn.Linear(32 * 5 * 5, 128)
        # 全连接层2: 输入为128,输出为10(10个类别)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # 卷积层 + ReLU + 最大池化层
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        # 展平成一维向量
        x = x.view(-1, 32 * 5 * 5)
        # 全连接层 + ReLU
        x = F.relu(self.fc1(x))
        # 输出层
        x = self.fc2(x)
        return x

# 实例化模型
model = CNN()
print(model)

说明

  • conv1conv2:卷积层用于提取图像特征。第一个卷积层从 1 个输入通道(灰度图像)转换为 16 个特征图,第二个卷积层将 16 个特征图转换为 32 个特征图。
  • max_pool2d:最大池化层,用于下采样特征图,将特征图尺寸减半。
  • fc1fc2:全连接层,用于将卷积层提取到的特征进行分类。

6. 训练模型

我们将定义损失函数和优化器,然后在训练数据集上训练模型。

import torch.optim as optim

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 将模型移动到 GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 训练模型
epochs = 5
for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}")

print("训练完成!")

说明

  • CrossEntropyLoss:用于分类任务的损失函数,适用于多分类问题。
  • optimizer:使用 Adam 优化器,能够自动调整学习率并加快收敛速度。
  • 训练过程包括前向传播、损失计算、反向传播和参数更新。

7. 评估模型性能

在训练完成后,我们将使用测试数据集来评估模型的性能,计算模型在测试集上的准确率。

# 测试模型
model.eval()  # 切换到评估模式
correct = 0
total = 0

with torch.no_grad():  # 关闭梯度计算
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'测试集上的准确率: {100 * correct / total:.2f}%')

说明

  • model.eval():在评估模型时关闭 dropout 和 batch normalization。
  • torch.no_grad():关闭梯度计算以提高测试阶段的效率。

8. 进行预测

最后,我们可以使用训练好的模型对手写数字图像进行预测。

# 从测试集中取出一个样本
example_data, example_target = next(iter(test_loader))
example_data = example_data.to(device)

# 使用模型进行预测
model.eval()
with torch.no_grad():
    output = model(example_data)

# 可视化预测结果
plt.figure(figsize=(10, 3))
for i in range(6):
    plt.subplot(1, 6, i + 1)
    plt.imshow(example_data[i][0].cpu(), cmap='gray')
    plt.title(f"预测: {torch.argmax(output[i]).item()}")
    plt.axis('off')
plt.show()

说明

  • 取出测试集中的一批样本进行预测,并可视化模型的预测结果。

9. 总结

在本教程中,我们使用 PyTorch 实现了手写数字识别任务,构建了一个简单的卷积神经网络(CNN),并在 MNIST 数据集上进行了训练和评估。通过此项目,您可以了解如何加载数据、构建模型、训练、评估和测试 PyTorch 模型。

10. 改进方向

  • 增加网络深度:可以增加卷积层和全连接层的

数量,提高模型的表现。

  • 使用数据增强:通过数据增强技术(旋转、缩放等),可以提高模型的泛化能力。
  • 应用在其他数据集:除了 MNIST,还可以将模型应用到其他数据集,如 FashionMNIST、CIFAR-10 等。

http://www.kler.cn/news/314606.html

相关文章:

  • 2024华为杯数模CDEF成品文章!【配套完整解题代码+数据处理】
  • 一文读懂 JS 中的 Map 结构
  • 图形化编程012(变量-倒计时)
  • 【JVM原理】运行时数据区(内存结构)
  • 前端框架的比较与选择详解
  • 数据库提权【笔记总结】
  • 计算机毕业设计 社区医疗服务系统的设计与实现 Java实战项目 附源码+文档+视频讲解
  • web基础—dvwa靶场(四)​File Inclusion
  • 电脑文件防泄密软件哪个好?这六款软件建议收藏【精选推荐】
  • MQ(RabbitMQ)笔记
  • Flutter 约束布局
  • 充电桩项目:前端实现
  • Ubuntu 安装 OpenGL 开发库
  • leetcode第十四题:最长公共前缀
  • 12.Java基础概念-面向对象-static
  • 2024“华为杯”中国研究生数学建模竞赛(A题)深度剖析_数学建模完整过程+详细思路+代码全解析
  • 无线安全(WiFi)
  • 【MySQ】在MySQL里with 的用法
  • 【技术解析】消息中间件MQ:从原理到RabbitMQ实战(深入浅出)
  • 计算机毕业设计之:基于微信小程序的校园流浪猫收养系统(源码+文档+讲解)
  • WEB 编程:富文本编辑器 Quill 配合 Pico.css 样式被影响的问题
  • vue配置axios
  • 使用Java实现高效用户行为监控系统
  • 二叉树(二)深度遍历和广度遍历
  • Redis的三种持久化方法详解
  • Spring Boot实战:使用策略模式优化商品推荐系统
  • linux用户管理运行级别找回root密码
  • 【Java注解】
  • Docker指令学习1
  • 【Kubernetes】常见面试题汇总(二十七)