当前位置: 首页 > article >正文

二、模型训练与优化(4):模型优化-实操

下面我将以 MNIST 手写数字识别模型为例,从 剪枝 (Pruning)量化 (Quantization) 两个常用方法出发,提供一套可实际动手操作的模型优化流程。此示例基于 TensorFlow/Keras 环境,示范如何先训练一个基础模型,然后对其进行剪枝和量化,最后验证优化后的模型性能。


目录

  1. 整体流程概览
  2. 模型剪枝 (Pruning)
    1. 安装依赖库
    2. 修改训练脚本实现剪枝
    3. 如何运行剪枝脚本
    4. 检查与验证剪枝后模型
  3. 模型量化 (Quantization)
    1. 原理与应用场景
    2. 在脚本中添加量化步骤
    3. 运行量化脚本
    4. 验证量化后模型
  4. 常见问题与建议
  5. 总结

1. 整体流程概览

在之前博客中已经可以训练一个基础 MNIST 模型(train_mnist.py)并成功获得 mnist_model.h5 的前提下,通常会按照以下顺序进行优化:

在模型训练好后,可以在mnist_project文件夹下找到mnist_model.h5,如下:

  1. 剪枝 (Pruning):减小模型大小、去除不重要的权重,生成 pruned_mnist_model.h5
  2. (可选)量化 (Quantization):将浮点模型转化为 INT8 等低比特模型,大幅减小模型体积,并提升推理速度,生成 mnist_model_quant.tflite

在此过程中,我们需要:

  • 修改已有脚本新增脚本来执行剪枝和量化的操作。
  • 确保虚拟环境已安装必要库tensorflow-model-optimizationtensorflow-lite等)。
  • 反复验证模型的大小、推理速度、准确率,找到最适合部署需求的平衡点。

2. 模型剪枝 (Pruning)

2.1 安装依赖库

  • TensorFlow Model Optimization Toolkit:其中包含 tfmot.sparsity.keras 模块,可用于剪枝、量化感知训练等。

在激活的虚拟环境(tf_env 等)下,输入:

pip install tensorflow-model-optimization

如果已经安装过,可以跳过此步骤;若版本较旧,建议 pip install --upgrade tensorflow-model-optimization

2.2 修改训练脚本实现剪枝

这里给出的示例代码可放在一个新的脚本(如 prune_mnist.py),或者在原 train_mnist.py 中替换。示例如下:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_model_optimization as tfmot

def main():
    # 1. 加载 MNIST 数据集
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # 2. 数据预处理
    x_train = x_train.astype("float32") / 255.0
    x_test  = x_test.astype("float32") / 255.0
    x_train = x_train.reshape(-1, 28 * 28)
    x_test  = x_test.reshape(-1, 28 * 28)

    # 3. 定义剪枝参数
    pruning_params = {
        # PolynomialDecay 让剪枝率从 initial_sparsity 到 final_sparsity 逐渐增加
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.0,     # 初始剪枝率 (0%)
            final_sparsity=0.5,       # 最终剪枝率 (50%)
            begin_step=0,             # 剪枝开始 step
            end_step=np.ceil(len(x_train) / 64).astype(np.int32) * 5
            # end_step: 这里相当于 epochs * (训练集样本数 / batch_size)
        )
    }

    # 4. 构建剪枝后的模型
    #   - 先定义一个包含1~2层的网络
    #   - 使用 prune_low_magnitude 对最后一层进行剪枝封装
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        tfmot.sparsity.keras.prune_low_magnitude(
            tf.keras.layers.Dense(10, activation='softmax'),
            **pruning_params
        )
    ])

    # 5. 编译模型
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    # 6. 设置剪枝回调
    #    - UpdatePruningStep:在每个批次/epoch后更新剪枝进度
    #    - PruningSummaries:可选,将剪枝信息写入到指定 log_dir,配合 TensorBoard 查看
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir='logs')
    ]

    # 7. 训练模型
    #    - epochs=5 可以根据需要加大或减少
    history = model.fit(
        x_train, y_train,
        epochs=5,
        batch_size=64,
        validation_split=0.1,
        callbacks=callbacks
    )

    # 8. 模型评估
    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
    print(f"\n测试集上的准确率: {test_acc:.4f}")

    # 9. 保存剪枝后的模型
    #    - 先使用 strip_pruning 去除剪枝包装器,得到最终“瘦身”模型
    final_model = tfmot.sparsity.keras.strip_pruning(model)
    final_model.save("pruned_mnist_model.h5")

    # 10. 可视化训练过程
    plot_history(history)

def plot_history(history):
    """可视化训练曲线"""
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs_range = range(len(acc))

    plt.figure(figsize=(12, 4))

    # 绘制准确率曲线
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, acc, label='训练准确率')
    plt.plot(epochs_range, val_acc, label='验证准确率')
    plt.legend(loc='lower right')
    plt.title('训练和验证准确率')

    # 绘制损失曲线
    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='训练损失')
    plt.plot(epochs_range, val_loss, label='验证损失')
    plt.legend(loc='upper right')
    plt.title('训练和验证损失')

    plt.show()

if __name__ == "__main__":
    main()
代码要点
  1. tfmot.sparsity.keras.PolynomialDecay:定义从 0% 到 50% 的剪枝率逐渐增加的策略。
  2. prune_low_magnitude(...):对目标层进行剪枝包装。可以只对某些关键层做剪枝,也可对网络所有层做封装。
  3. strip_pruning(...):剪枝训练完后,需要去掉剪枝相关的“假”节点,才能得到真正稀疏的权重以减小体积。

2.3 如何运行剪枝脚本

  1. 确保已经训练过一个基础模型(可选,如果想微调原模型);或者像示例这样直接在脚本里构建一个新的网络。
  2. 打开 Anaconda Prompt(或终端),激活虚拟环境
    conda activate tf_env
    
  3. 导航到脚本所在目录
    cd C:\Users\FCZ\Desktop\Projects\mnist_project
    
  4. 运行脚本
    python prune_mnist.py
    

训练过程结束后,会打印出测试集准确率,并在目录下生成 pruned_mnist_model.h5

2.4 检查与验证剪枝后模型

  1. 模型体积:相较原始不剪枝模型,pruned_mnist_model.h5 通常会更小,但因 HDF5 格式本身包含稀疏权重的表示方式,实际文件大小并不总是线性减少。关键是剪枝会让权重矩阵变得稀疏,后续可以配合特定框架(如 STM32Cube.AI)进行再处理。
  2. 准确率:可能略有降低,一般会在 0.97~0.98 附近。若下降过多,可调整 final_sparsity (如从 0.5 改为 0.3) 或增加微调 epochs。
  3. 后续可做量化:将剪枝后模型再进行量化,可实现进一步体积和推理速度的提升。

3. 模型量化 (Quantization)

3.1 原理与应用场景

  • 量化:把模型中的权重(和激活)从 float32 转化成 int8、float16 等低位格式,典型方式是使用 TensorFlow Lite 的离线量化。
  • 适用场景:需要在嵌入式或移动端部署,同时希望降低模型大小和加速推理。
  • 代价:可能带来少量精度损失。如果需要减小精度损失,可用量化感知训练(QAT)。

3.2 在脚本中添加量化步骤

当我们在 train_mnist.py 训练完基础模型后,在prune_mnist.py完成剪枝操作后,接下来完成量化操作,编写单独脚本 quantize_mnist.py:将 训练、剪枝、量化 三个步骤整合在一起

"""
quantize_mnist.py
-----------------
在同一个脚本中完成:
1. MNIST 基础模型训练
2. 剪枝 (Pruning)
3. 量化 (Quantization)

依赖:
  - tensorflow>=2.5
  - tensorflow-model-optimization
  - numpy, matplotlib (可选, 用于可视化)
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_model_optimization as tfmot

def load_mnist_data():
    """加载 MNIST 数据,并做基本预处理。"""
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.astype("float32") / 255.0
    x_test  = x_test.astype("float32") / 255.0

    # 展开 28x28 -> 784
    x_train = x_train.reshape(-1, 28 * 28)
    x_test  = x_test.reshape(-1, 28 * 28)
    return (x_train, y_train), (x_test, y_test)

def create_base_model():
    """构建一个简单的全连接 MNIST 模型。"""
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model

def plot_history(history, title_prefix=""):
    """可视化训练曲线"""
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs_range = range(len(acc))

    plt.figure(figsize=(12, 4))

    # 准确率曲线
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, acc, label='训练准确率')
    plt.plot(epochs_range, val_acc, label='验证准确率')
    plt.legend(loc='lower right')
    plt.title(f'{title_prefix} 训练和验证准确率')

    # 损失曲线
    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='训练损失')
    plt.plot(epochs_range, val_loss, label='验证损失')
    plt.legend(loc='upper right')
    plt.title(f'{title_prefix} 训练和验证损失')

    plt.show()

def main():
    # =======================================
    # 1. 数据准备
    # =======================================
    (x_train, y_train), (x_test, y_test) = load_mnist_data()

    # =======================================
    # 2. 训练基线模型
    # =======================================
    print("\n--- 步骤1: 训练基线模型 ---")
    base_model = create_base_model()
    base_model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    history_base = base_model.fit(
        x_train, y_train,
        epochs=5,
        batch_size=64,
        validation_split=0.1
    )

    test_loss_base, test_acc_base = base_model.evaluate(x_test, y_test, verbose=0)
    print(f"基线模型测试集准确率: {test_acc_base:.4f}")

    # 可视化基线模型训练过程
    plot_history(history_base, title_prefix="基线模型")

    # 保存基线模型
    base_model.save("mnist_model.h5")

    # =======================================
    # 3. 剪枝 (Pruning)
    # =======================================
    print("\n--- 步骤2: 剪枝模型 ---")
    # 定义剪枝参数:从0%渐增到50%的剪枝率
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.0,
            final_sparsity=0.5,
            begin_step=0,
            end_step=np.ceil(len(x_train) / 64).astype(np.int32) * 5
        )
    }

    # 用之前的 base_model 权重来构造可剪枝模型
    # 也可直接对 base_model 做 prune_low_magnitude,但这里分开写更清晰
    pruned_model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        tfmot.sparsity.keras.prune_low_magnitude(
            tf.keras.layers.Dense(10, activation='softmax'),
            **pruning_params
        )
    ])

    # 把 base_model 的第一层权重复制到 pruned_model 第1层
    pruned_model.layers[0].set_weights(base_model.layers[0].get_weights())

    # 编译可剪枝模型
    pruned_model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    # 设置回调:更新剪枝步数 + 记录日志
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir='logs')
    ]

    history_pruned = pruned_model.fit(
        x_train, y_train,
        epochs=3,          # 可以适当增加训练轮数
        batch_size=64,
        validation_split=0.1,
        callbacks=callbacks
    )

    test_loss_pruned, test_acc_pruned = pruned_model.evaluate(x_test, y_test, verbose=0)
    print(f"剪枝后模型测试集准确率: {test_acc_pruned:.4f}")
    plot_history(history_pruned, title_prefix="剪枝模型")

    # strip_pruning: 得到真正稀疏的权重
    final_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
    final_pruned_model.save("pruned_mnist_model.h5")

    # =======================================
    # 4. 量化 (Quantization)
    # =======================================
    print("\n--- 步骤3: 量化剪枝后模型 (PTQ) ---")
    # 您也可以对 base_model 做量化,这里演示对 剪枝后的模型 做量化
    converter = tf.lite.TFLiteConverter.from_keras_model(final_pruned_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    # 如需要 representative_dataset 来校准,可添加:
    # converter.representative_dataset = ...

    # 转换为 TFLite
    tflite_quant_model = converter.convert()

    # 保存量化后的 TFLite 文件
    with open('pruned_mnist_model_quant.tflite', 'wb') as f:
        f.write(tflite_quant_model)

    print("量化后的剪枝模型已保存: pruned_mnist_model_quant.tflite")

    # 如有需要,可使用 tflite interpreter 测试推理
    # 这里仅演示到生成 TFLite 文件即可

if __name__ == "__main__":
    main()
注意
  • 量化完成后,记得在 PC 或嵌入式设备上进行推理测试,查看最终精度。

3.3 运行量化脚本

  1. 依旧在 Anaconda Prompt 中激活环境conda activate tf_env
  2. 导航到脚本所在目录
  3. 执行: python quantize_mnist.py
  4. 观察输出:若无异常,脚本会提示 "量化后的模型已保存为 mnist_model_quant.tflite"
  • 结果文件

  •          mnist_model.h5:基线模型(未剪枝、未量化)。

       pruned_mnist_model.h5:剪枝后且 strip_pruning 的 Keras 模型。

       pruned_mnist_model_quant.tflite:剪枝后再量化的 TFLite 模型,通常体积最小,速度也更快(具体依赖硬件支持)。

  • 训练基线模型:训练 5 轮得到 mnist_model.h5
  • 剪枝模型:基于基线模型的权重进行剪枝,训练 3 轮得到 pruned_mnist_model.h5
  • 量化模型:将剪枝后的模型转换为 .tflite 格式,并保存为 pruned_mnist_model_quant.tflite

总结

接下来对优化后的模型进行验证。


http://www.kler.cn/a/481817.html

相关文章:

  • 你好,2025!JumpServer开启新十年
  • ADO.NET知识总结4---SqlParameter参数
  • 每日一题-两个链表的第一个公共结点
  • Mysql - 多表连接和连接类型
  • 小程序组件 —— 30 组件 - 背景图片的使用
  • Python学习笔记:显示进度条
  • ip属地出省会变吗?怎么出省让ip属地不变
  • spring mvc源码学习笔记之十
  • 【蓝桥杯选拔赛真题60】C++寻宝石 第十四届蓝桥杯青少年创意编程大赛 算法思维 C++编程选拔赛真题解
  • Java 锁:多线程环境下的同步机制
  • 深度学习概述
  • 【Three.js基础学习】34.Earch Shaders
  • Redis 管道技术(Pipeline)
  • 2025新春烟花代码(二)HTML5实现孔明灯和烟花效果
  • 源代码防泄漏一机两用合体方案
  • 芯片详细讲解,从而区分CPU、MPU、DSP、GPU、FPGA、MCU、SOC、ECU
  • 数据结构:LinkedList与链表—无头单向链表(一)
  • 解决OPenMP不能使用头文件#include <omp.h> 的问题
  • SQLite PRAGMA
  • LQ quarter 5th
  • 缓存-Redis-缓存更新策略-主动更新策略-Cache Aside Pattern(全面 易理解)
  • OpenAI 宣称已掌握构建通用人工智能 (AGI) 的方法| 0107AI日报
  • Docker--Docker Volume(存储卷)
  • 【Java 注解】从入门到精通:上篇
  • 安装完docker后,如何拉取ubuntu镜像并创建容器?
  • 机器学习无处不在,AI顺势而为,创新未来