当前位置: 首页 > article >正文

回归任务训练--MNIST全连接神经网络(Mnist_NN)

import torch
import numpy as np
import logging
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import DataLoader

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 定义 loss_batch 函数,添加类型注解
def loss_batch(model: torch.nn.Module, loss_func: torch.nn.Module, xb: torch.Tensor, yb: torch.Tensor, opt=None) -> tuple:
    loss = loss_func(model(xb), yb)
    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()
    return loss.item(), len(xb)

# 定义一个简单的模型,添加类型注解
class SimpleModel(torch.nn.Module):
    def __init__(self, input_dim: int):
        super(SimpleModel, self).__init__()
        self.fc = torch.nn.Linear(input_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

# 生成模拟数据
num_train_samples = 1000
input_dim = 10
x_train = torch.randn(num_train_samples, input_dim)
y_train = torch.randn(num_train_samples, 1)

# 数据标准化
x_train_mean = x_train.mean(dim=0)
x_train_std = x_train.std(dim=0)
x_train = (x_train - x_train_mean) / x_train_std

train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)

num_valid_samples = 200
x_valid = torch.randn(num_valid_samples, input_dim)
y_valid = torch.randn(num_valid_samples, 1)

# 验证数据标准化
x_valid = (x_valid - x_train_mean) / x_train_std

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=64)

# 初始化模型、损失函数和优化器
model = SimpleModel(input_dim)
loss_func = torch.nn.MSELoss()
opt = torch.optim.SGD(model.parameters(), lr=0.01)

# 定义 fit 函数,添加早停机制
def fit(steps: int, model: torch.nn.Module, loss_func: torch.nn.Module, opt: torch.optim.Optimizer, train_dl: DataLoader, valid_dl: DataLoader, patience: int = 3):
    best_val_loss = float('inf')
    no_improvement_count = 0
    for step in range(steps):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        logging.info(f'当前step: {step}, 验证集损失:{val_loss}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improvement_count = 0
        else:
            no_improvement_count += 1

        if no_improvement_count >= patience:
            logging.info(f'早停:验证集损失在 {patience} 个步骤内没有改善。')
            break

# 调用 fit 函数进行训练
fit(50, model, loss_func, opt, train_dl, valid_dl)

1. 库导入与日志配置

import torch
import numpy as np
import logging
from torch.utils.data import TensorDataset, DataLoader
  • 功能

    • 导入 PyTorch 深度学习框架和 NumPy 数值计算库。

    • 配置日志系统,记录训练过程中的关键信息。

  • 关键点

    • TensorDataset 和 DataLoader 用于封装数据集并实现批量加载。


2. 定义损失计算函数 loss_batch

def loss_batch(model: torch.nn.Module, loss_func: torch.nn.Module, 
              xb: torch.Tensor, yb: torch.Tensor, opt=None) -> tuple:
    loss = loss_func(model(xb), yb)
    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()
    return loss.item(), len(xb)
  • 功能

    • 计算模型在单个批次的损失,并根据是否提供优化器 opt 决定是否更新模型参数。

  • 参数

    • model: 神经网络模型。

    • loss_func: 损失函数(如 MSELoss)。

    • xbyb: 输入数据和标签的批次。

    • opt: 优化器(如 SGD),若为 None 则仅计算损失。

  • 返回值

    • 当前批次的损失值(标量)和批次大小。


3. 定义简单线性模型 SimpleModel

class SimpleModel(torch.nn.Module):
    def __init__(self, input_dim: int):
        super(SimpleModel, self).__init__()
        self.fc = torch.nn.Linear(input_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)
  • 功能

    • 定义一个单层全连接网络,用于回归任务。

  • 结构

    • 输入维度 input_dim,输出维度 1。

    • 示例:若 input_dim=10,则模型为 Linear(10 → 1)


4. 生成与标准化模拟数据

# 生成训练数据
num_train_samples = 1000
input_dim = 10
x_train = torch.randn(num_train_samples, input_dim)
y_train = torch.randn(num_train_samples, 1)
x_train_mean = x_train.mean(dim=0)
x_train_std = x_train.std(dim=0)
x_train = (x_train - x_train_mean) / x_train_std

# 创建 DataLoader
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)

# 生成验证数据(使用训练数据的均值和标准差标准化)
num_valid_samples = 200
x_valid = torch.randn(num_valid_samples, input_dim)
y_valid = torch.randn(num_valid_samples, 1)
x_valid = (x_valid - x_train_mean) / x_train_std
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=64)
  • 功能

    • 生成随机训练和验证数据,并进行标准化处理。

  • 关键点

    • 标准化时使用训练数据的均值和标准差,避免数据泄露(Data Leakage)。

    • DataLoader 设置不同批量大小(训练 32,验证 64),提升训练效率。


5. 初始化模型、损失函数与优化器

model = SimpleModel(input_dim)
loss_func = torch.nn.MSELoss()
opt = torch.optim.SGD(model.parameters(), lr=0.01)
  • 功能

    • 定义回归任务的均方误差损失(MSE)和随机梯度下降优化器(SGD)。

  • 参数

    • 学习率 lr=0.01 控制参数更新步长。


6. 定义训练函数 fit 并实现早停机制

def fit(steps: int, model: torch.nn.Module, loss_func: torch.nn.Module, 
        opt: torch.optim.Optimizer, train_dl: DataLoader, 
        valid_dl: DataLoader, patience: int = 3):
    best_val_loss = float('inf')
    no_improvement_count = 0
    for step in range(steps):
        # 训练阶段
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)
        
        # 验证阶段
        model.eval()
        with torch.no_grad():
            losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        logging.info(f'当前step: {step}, 验证集损失:{val_loss}')
        
        # 早停判断
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improvement_count = 0
        else:
            no_improvement_count += 1
        if no_improvement_count >= patience:
            logging.info(f'早停:验证集损失在 {patience} 个步骤内没有改善。')
            break
  • 功能

    • 执行模型训练,并在验证集上监控损失,实现早停机制。

  • 流程

    1. 训练模式:遍历训练数据,更新模型参数。

    2. 评估模式:关闭梯度计算,计算验证集损失。

    3. 早停机制:若连续 patience 次验证损失未改善,提前终止训练。

  • 关键点

    • model.train() 和 model.eval() 切换模型模式(影响 Dropout/BatchNorm 等层)。

    • 验证损失通过加权平均计算(考虑不同批次大小)。


7. 启动训练

fit(50, model, loss_func, opt, train_dl, valid_dl)
  • 参数

    • steps=50:最大训练轮数(epoch)。

    • patience=3:允许验证损失不改进的最大轮数。


代码总结

核心功能
  • 实现了一个简单的回归任务训练流程,包含数据标准化、模型训练、验证和早停机制。

  • 使用 PyTorch 的 DataLoader 实现高效数据加载,支持批量训练与验证。

 网络结构

改进建议
  1. 模型复杂度

    • 当前模型为单层线性网络,可增加隐藏层和非线性激活函数(如 ReLU)提升表达能力。

    self.fc1 = nn.Linear(input_dim, 64)
    self.fc2 = nn.Linear(64, 1)

  2. 学习率调度

    添加学习率衰减(如 torch.optim.lr_scheduler.StepLR)加速收敛。
  3. 数据增强

    若为真实数据,可添加噪声或变换增强鲁棒性。
  4. 日志增强

    记录训练损失和验证损失的曲线,便于分析过拟合。

http://www.kler.cn/a/595765.html

相关文章:

  • 构建教育类小程序:核心功能详解
  • SpringBoot @Scheduled注解详解
  • 浏览器工作原理深度解析(阶段二):HTML 解析与 DOM 树构建
  • Json的应用实例——cad 二次开发c#
  • 最近比突出的DeepSeek与ChatGPT的详细比较分析
  • 【k8s】利用Kubernetes卷快照实现高效的备份和恢复
  • C++具名转型的功能和用途
  • 基于SpringBoot的“校园招聘网站”的设计与实现(源码+数据库+文档+PPT)
  • 使用spring-ai-ollama访问本地化部署DeepSeek
  • 企业信息化的“双螺旋”——IT治理和数据治理
  • MySQL0基础学习记录-下载与安装
  • 光影香江聚四海,蓝陵科技扬帆数字内容新蓝海
  • 充分了解深度学习
  • Jsonpath使用
  • 游戏MOD伴随盗号风险,仿冒网站借“风灵月影”窃密【火绒企业版V2.0】
  • 【linux】防止SSD掉盘导致无法 reboot 软重启
  • Mysql表的简单操作
  • 嵌入式开发之STM32学习笔记day07
  • 基于RAGFlow本地部署DeepSeek-R1大模型与知识库:从配置到应用的全流程解析
  • UR5e机器人位姿