回归任务训练--MNIST全连接神经网络(Mnist_NN)
import torch
import numpy as np
import logging
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import DataLoader
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# 定义 loss_batch 函数,添加类型注解
def loss_batch(model: torch.nn.Module, loss_func: torch.nn.Module, xb: torch.Tensor, yb: torch.Tensor, opt=None) -> tuple:
loss = loss_func(model(xb), yb)
if opt is not None:
loss.backward()
opt.step()
opt.zero_grad()
return loss.item(), len(xb)
# 定义一个简单的模型,添加类型注解
class SimpleModel(torch.nn.Module):
def __init__(self, input_dim: int):
super(SimpleModel, self).__init__()
self.fc = torch.nn.Linear(input_dim, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.fc(x)
# 生成模拟数据
num_train_samples = 1000
input_dim = 10
x_train = torch.randn(num_train_samples, input_dim)
y_train = torch.randn(num_train_samples, 1)
# 数据标准化
x_train_mean = x_train.mean(dim=0)
x_train_std = x_train.std(dim=0)
x_train = (x_train - x_train_mean) / x_train_std
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
num_valid_samples = 200
x_valid = torch.randn(num_valid_samples, input_dim)
y_valid = torch.randn(num_valid_samples, 1)
# 验证数据标准化
x_valid = (x_valid - x_train_mean) / x_train_std
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=64)
# 初始化模型、损失函数和优化器
model = SimpleModel(input_dim)
loss_func = torch.nn.MSELoss()
opt = torch.optim.SGD(model.parameters(), lr=0.01)
# 定义 fit 函数,添加早停机制
def fit(steps: int, model: torch.nn.Module, loss_func: torch.nn.Module, opt: torch.optim.Optimizer, train_dl: DataLoader, valid_dl: DataLoader, patience: int = 3):
best_val_loss = float('inf')
no_improvement_count = 0
for step in range(steps):
model.train()
for xb, yb in train_dl:
loss_batch(model, loss_func, xb, yb, opt)
model.eval()
with torch.no_grad():
losses, nums = zip(
*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
)
val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
logging.info(f'当前step: {step}, 验证集损失:{val_loss}')
if val_loss < best_val_loss:
best_val_loss = val_loss
no_improvement_count = 0
else:
no_improvement_count += 1
if no_improvement_count >= patience:
logging.info(f'早停:验证集损失在 {patience} 个步骤内没有改善。')
break
# 调用 fit 函数进行训练
fit(50, model, loss_func, opt, train_dl, valid_dl)
1. 库导入与日志配置
import torch
import numpy as np
import logging
from torch.utils.data import TensorDataset, DataLoader
-
功能:
-
导入 PyTorch 深度学习框架和 NumPy 数值计算库。
-
配置日志系统,记录训练过程中的关键信息。
-
-
关键点:
-
TensorDataset
和DataLoader
用于封装数据集并实现批量加载。
-
2. 定义损失计算函数 loss_batch
def loss_batch(model: torch.nn.Module, loss_func: torch.nn.Module,
xb: torch.Tensor, yb: torch.Tensor, opt=None) -> tuple:
loss = loss_func(model(xb), yb)
if opt is not None:
loss.backward()
opt.step()
opt.zero_grad()
return loss.item(), len(xb)
-
功能:
-
计算模型在单个批次的损失,并根据是否提供优化器
opt
决定是否更新模型参数。
-
-
参数:
-
model
: 神经网络模型。 -
loss_func
: 损失函数(如 MSELoss)。 -
xb
,yb
: 输入数据和标签的批次。 -
opt
: 优化器(如 SGD),若为None
则仅计算损失。
-
-
返回值:
-
当前批次的损失值(标量)和批次大小。
-
3. 定义简单线性模型 SimpleModel
class SimpleModel(torch.nn.Module):
def __init__(self, input_dim: int):
super(SimpleModel, self).__init__()
self.fc = torch.nn.Linear(input_dim, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.fc(x)
-
功能:
-
定义一个单层全连接网络,用于回归任务。
-
-
结构:
-
输入维度
input_dim
,输出维度 1。 -
示例:若
input_dim=10
,则模型为Linear(10 → 1)
。
-
4. 生成与标准化模拟数据
# 生成训练数据
num_train_samples = 1000
input_dim = 10
x_train = torch.randn(num_train_samples, input_dim)
y_train = torch.randn(num_train_samples, 1)
x_train_mean = x_train.mean(dim=0)
x_train_std = x_train.std(dim=0)
x_train = (x_train - x_train_mean) / x_train_std
# 创建 DataLoader
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
# 生成验证数据(使用训练数据的均值和标准差标准化)
num_valid_samples = 200
x_valid = torch.randn(num_valid_samples, input_dim)
y_valid = torch.randn(num_valid_samples, 1)
x_valid = (x_valid - x_train_mean) / x_train_std
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=64)
-
功能:
-
生成随机训练和验证数据,并进行标准化处理。
-
-
关键点:
-
标准化时使用训练数据的均值和标准差,避免数据泄露(Data Leakage)。
-
DataLoader
设置不同批量大小(训练 32,验证 64),提升训练效率。
-
5. 初始化模型、损失函数与优化器
model = SimpleModel(input_dim)
loss_func = torch.nn.MSELoss()
opt = torch.optim.SGD(model.parameters(), lr=0.01)
-
功能:
-
定义回归任务的均方误差损失(MSE)和随机梯度下降优化器(SGD)。
-
-
参数:
-
学习率
lr=0.01
控制参数更新步长。
-
6. 定义训练函数 fit
并实现早停机制
def fit(steps: int, model: torch.nn.Module, loss_func: torch.nn.Module,
opt: torch.optim.Optimizer, train_dl: DataLoader,
valid_dl: DataLoader, patience: int = 3):
best_val_loss = float('inf')
no_improvement_count = 0
for step in range(steps):
# 训练阶段
model.train()
for xb, yb in train_dl:
loss_batch(model, loss_func, xb, yb, opt)
# 验证阶段
model.eval()
with torch.no_grad():
losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])
val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
logging.info(f'当前step: {step}, 验证集损失:{val_loss}')
# 早停判断
if val_loss < best_val_loss:
best_val_loss = val_loss
no_improvement_count = 0
else:
no_improvement_count += 1
if no_improvement_count >= patience:
logging.info(f'早停:验证集损失在 {patience} 个步骤内没有改善。')
break
-
功能:
-
执行模型训练,并在验证集上监控损失,实现早停机制。
-
-
流程:
-
训练模式:遍历训练数据,更新模型参数。
-
评估模式:关闭梯度计算,计算验证集损失。
-
早停机制:若连续
patience
次验证损失未改善,提前终止训练。
-
-
关键点:
-
model.train()
和model.eval()
切换模型模式(影响 Dropout/BatchNorm 等层)。 -
验证损失通过加权平均计算(考虑不同批次大小)。
-
7. 启动训练
fit(50, model, loss_func, opt, train_dl, valid_dl)
-
参数:
-
steps=50
:最大训练轮数(epoch)。 -
patience=3
:允许验证损失不改进的最大轮数。
-
代码总结
核心功能
-
实现了一个简单的回归任务训练流程,包含数据标准化、模型训练、验证和早停机制。
-
使用 PyTorch 的
DataLoader
实现高效数据加载,支持批量训练与验证。
网络结构
改进建议
-
模型复杂度:
-
当前模型为单层线性网络,可增加隐藏层和非线性激活函数(如 ReLU)提升表达能力。
self.fc1 = nn.Linear(input_dim, 64) self.fc2 = nn.Linear(64, 1)
-
-
学习率调度:
添加学习率衰减(如torch.optim.lr_scheduler.StepLR
)加速收敛。 -
数据增强:
若为真实数据,可添加噪声或变换增强鲁棒性。 -
日志增强:
记录训练损失和验证损失的曲线,便于分析过拟合。