基于 PyTorch 的 MNIST 手写数字分类模型
一、概述
本代码使用 PyTorch 框架构建了一个简单的神经网络模型,用于解决 MNIST 手写数字分类任务。代码主要包括数据的加载与预处理、神经网络模型的构建、损失函数和优化器的定义、模型的训练、评估以及最终模型的保存等步骤。
二、依赖库
torch
:PyTorch 深度学习框架的核心库,提供了张量操作、自动求导等功能。torch.nn
:PyTorch 的神经网络模块,包含了各种神经网络层、损失函数等。torch.optim
:PyTorch 的优化器模块,用于更新模型的参数。torch.utils.data
:提供了数据加载和处理的工具,如DataLoader
类。torchvision
:PyTorch 的计算机视觉库,包含了常用的数据集、图像变换等。
三、代码详解
1. 加载和预处理数据
python
# 使用MNIST数据集(手写数字分类任务)
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize((0.5,), (0.5,)) # 归一化到 [-1, 1]
])
# 下载并加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
- 图像变换:定义了一个
transform
对象,使用transforms.Compose
组合了两个操作。transforms.ToTensor()
将图像转换为 PyTorch 的张量格式;transforms.Normalize((0.5,), (0.5,))
对图像进行归一化处理,将像素值范围从[0, 1]
映射到[-1, 1]
。 - 数据集加载:使用
datasets.MNIST
分别加载训练集和测试集。root='./data'
指定数据的存储路径;train=True
表示加载训练集,train=False
表示加载测试集;download=True
表示如果数据不存在则自动下载;transform=transform
表示对加载的图像应用上述定义的变换。 - 数据加载器:使用
DataLoader
创建训练集和测试集的数据加载器。batch_size=64
表示每个批次包含 64 个样本;shuffle=True
表示在训练时对数据进行随机打乱,以增加模型的泛化能力;shuffle=False
表示在测试时不打乱数据,以便于结果的评估。
2. 构建神经网络模型
python
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten() # 将图像展平
self.fc1 = nn.Linear(28 * 28, 128) # 全连接层1
self.fc2 = nn.Linear(128, 64) # 全连接层2
self.fc3 = nn.Linear(64, 10) # 输出层
def forward(self, x):
x = self.flatten(x)
x = torch.relu(self.fc1(x)) # 激活函数ReLU
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
model = NeuralNetwork()
- 定义神经网络类:创建了一个名为
NeuralNetwork
的类,继承自nn.Module
,这是 PyTorch 中所有神经网络模型的基类。 - 初始化函数:在
__init__
方法中,定义了模型的各个层。nn.Flatten()
用于将输入的图像张量展平为一维向量;nn.Linear(28 * 28, 128)
、nn.Linear(128, 64)
和nn.Linear(64, 10)
分别定义了三个全连接层,输入和输出维度根据任务需求设置。 - 前向传播函数:
forward
方法定义了模型的前向传播过程。输入数据x
首先被展平,然后依次通过三个全连接层,并在中间层应用 ReLU 激活函数,最后返回输出结果。 - 创建模型实例:
model = NeuralNetwork()
创建了一个NeuralNetwork
类的实例,即我们的神经网络模型。
3. 定义损失函数和优化器
python
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器
- 损失函数:使用
nn.CrossEntropyLoss()
定义了交叉熵损失函数,用于计算模型预测结果与真实标签之间的差异。交叉熵损失函数常用于多分类任务。 - 优化器:选择
optim.Adam
作为优化器,model.parameters()
表示要更新的模型参数,lr=0.001
设置学习率为 0.001,学习率控制每次参数更新的步长。
4. 训练模型
python
def train(model, train_loader, criterion, optimizer, epochs=5):
model.train() # 设置为训练模式
for epoch in range(epochs):
running_loss = 0.0
for images, labels in train_loader:
optimizer.zero_grad() # 梯度清零
outputs = model(images) # 前向传播
loss = criterion(outputs, labels) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
running_loss += loss.item()
print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}")
train(model, train_loader, criterion, optimizer, epochs=5)
- 训练函数定义:
train
函数接受模型、训练数据加载器、损失函数、优化器和训练轮数epochs
作为参数。 - 设置训练模式:
model.train()
将模型设置为训练模式,此时模型中的一些层(如 Dropout、BatchNorm 等)会处于训练状态。 - 训练循环:外层循环迭代
epochs
次,内层循环遍历训练数据加载器中的每个批次。在每个批次中:optimizer.zero_grad()
将优化器的梯度清零,避免梯度累加。outputs = model(images)
进行前向传播,得到模型的预测输出。loss = criterion(outputs, labels)
计算预测输出与真实标签之间的损失。loss.backward()
进行反向传播,计算损失对模型参数的梯度。optimizer.step()
根据计算得到的梯度更新模型的参数。running_loss += loss.item()
累加每个批次的损失。
- 打印训练信息:每一轮训练结束后,打印当前轮数和平均损失。
5. 评估模型
python
def evaluate(model, test_loader):
model.eval() # 设置为评估模式
correct = 0
total = 0
with torch.no_grad(): # 禁用梯度计算
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1) # 获取预测结果
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")
evaluate(model, test_loader)
- 评估函数定义:
evaluate
函数接受模型和测试数据加载器作为参数。 - 设置评估模式:
model.eval()
将模型设置为评估模式,此时模型中的一些层(如 Dropout、BatchNorm 等)会处于评估状态。 - 禁用梯度计算:
with torch.no_grad()
上下文管理器禁用梯度计算,以提高评估效率,因为在评估阶段不需要计算梯度。 - 评估循环:遍历测试数据加载器中的每个批次,进行前向传播得到模型的预测输出,使用
torch.max(outputs.data, 1)
获取每个样本预测概率最大的类别索引,统计预测正确的样本数量和总样本数量。 - 计算准确率:根据统计结果计算测试集上的准确率,并打印出来。
6. 保存模型
python
torch.save(model.state_dict(), 'mnist_model.pth')
使用torch.save
函数将模型的参数保存到文件mnist_model.pth
中。model.state_dict()
返回模型的参数字典,后续可以使用torch.load
函数加载这些参数来恢复模型。
四、注意事项
- 代码中的神经网络模型结构相对简单,对于复杂的任务可能需要进一步优化和调整。
- 训练轮数
epochs
和学习率lr
等超参数可以根据实际情况进行调整,以获得更好的训练效果。 - 在实际应用中,可能需要进行更多的模型验证和调优,如使用验证集进行早停等。
- 确保运行代码的环境中已安装所需的依赖库,并且版本兼容。
完整代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 1. 加载和预处理数据
# 使用MNIST数据集(手写数字分类任务)
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize((0.5,), (0.5,)) # 归一化到 [-1, 1]
])
# 下载并加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 2. 构建神经网络模型
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten() # 将图像展平
self.fc1 = nn.Linear(28 * 28, 128) # 全连接层1
self.fc2 = nn.Linear(128, 64) # 全连接层2
self.fc3 = nn.Linear(64, 10) # 输出层
def forward(self, x):
x = self.flatten(x)
x = torch.relu(self.fc1(x)) # 激活函数ReLU
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
model = NeuralNetwork()
# 3. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器
# 4. 训练模型
def train(model, train_loader, criterion, optimizer, epochs=5):
model.train() # 设置为训练模式
for epoch in range(epochs):
running_loss = 0.0
for images, labels in train_loader:
optimizer.zero_grad() # 梯度清零
outputs = model(images) # 前向传播
loss = criterion(outputs, labels) # 计算损失
loss.backward() # 反向传播
optimizer.step() # 更新参数
running_loss += loss.item()
print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}")
train(model, train_loader, criterion, optimizer, epochs=5)
# 5. 评估模型
def evaluate(model, test_loader):
model.eval() # 设置为评估模式
correct = 0
total = 0
with torch.no_grad(): # 禁用梯度计算
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1) # 获取预测结果
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")
evaluate(model, test_loader)
# 6. 保存模型
torch.save(model.state_dict(), 'mnist_model.pth')