当前位置: 首页 > article >正文

深度学习blog-剪枝和知识蒸馏

深度学习网络模型从卷积层到全连接层存在着大量冗余的参数,大量神经元激活值趋近于0,将这些神经元去除后可以表现出同样的模型表达能力,这种情况被称为过参数化。因此需要一些技术手段减少模型的复杂性,去除一些不重要的参数和连接,从而提高模型在推理阶段的效率,减少存储需求,同时可能还能够降低过拟合的风险。

常用的模型的压缩和轻量化加速技术有:

  1. 权重剪枝:通过删除神经网络中冗余的权重来减少模型的复杂度和计算量。具体来说,可以通过设定一个阈值来判断权重的重要性,然后将不重要的权重设置为零或删除。
  2. 模型量化:将神经网络中的权重和激活值从浮点数转换为低精度的整数表示,从而减少模型的存储空间和计算量。
  3. 知识蒸馏(Knowledge Distillation):这是一种特殊的模型蒸馏技术,其中教师模型和学生模型具有相同的架构,但参数不同。通过让学生模型学习教师模型的输出,可以实现模型的压缩和加速。
  4. 知识提炼(Knowledge Carving):选择性地从教师模型中抽取部分子结构用于构建学生模型。
  5. 网络剪枝(Network Pruning):通过删除神经网络中冗余的神经元或连接来减少模型的复杂度和计算量。具体来说,可以通过设定一个阈值来判断神经元或连接的重要性,然后将不重要的神经元或连接删除。
  6. 低秩分解(Low-Rank Factorization):将神经网络中的权重矩阵分解为两个低秩矩阵的乘积,从而减少模型的存储空间和计算量。这种方法可以应用于卷积层和全连接层等不同类型的神经网络层。
  7. 结构搜索(Neural Architecture Search):通过自动搜索最优的神经网络结构来实现模型的压缩和加速。这种方法可以根据特定任务的需求来定制适合的神经网络结构。

剪枝(Pruning)是深度学习和神经网络中常用的一种模型压缩技术。

1. 剪枝的背景
深度学习模型通常由大量的参数组成,尤其是在深层神经网络中,这些参数使得模型能力强大,但也导致计算和存储成本高。为了在工业应用中将模型部署到资源有限的设备上,剪枝成为了重要的研究方向。

模型剪枝主要分为结构化剪枝非结构化剪枝非结构化剪枝去除不重要的神经元,相应地,被剪除的神经元和其他神经元之间的连接在计算时会被忽略。由于剪枝后的模型通常很稀疏,并且破坏了原有模型的结构,所以这类方法被称为非结构化剪枝

2. 剪枝粒度分类
 

  •  细粒度剪枝(fine-grained):即对连接或者神经元进行剪枝,它是粒度最小的剪枝。
  • 向量剪枝(vector-level):它相对于细粒度剪枝粒度更大,属于对卷积核内部(intra-kernel)的剪枝。
  • 核剪枝(kernel-level):即去除某个卷积核,它将丢弃对输入通道中对应计算通道的响应。
  • 滤波器剪枝(Filter-level):对整个卷积核组进行剪枝,会造成推理过程中输出特征通道数的改变。

3. 剪枝的流程
剪枝的基本流程通常包括以下几个步骤:

训练:首先对神经网络进行训练,直到达到满意的精度。
评估重要性:使用特定的标准(如权重的绝对值、梯度等)来评估每个参数或神经元的重要性。
剪枝:根据重要性评估结果,去除一些参数或神经元。
微调(Fine-tuning):对剪枝后的模型进行再训练,以恢复模型的性能。

4. 剪枝的优缺点
优点:
减少计算复杂度:剪枝后,模型推理速度更快。
降低存储需求:模型所需的存储空间减少。
提高模型泛化能力:可能减少过拟合,并提高在新数据上的表现。
缺点:
剪枝带来的性能损失:不当的剪枝可能导致模型精度下降。
额外的复杂性:剪枝和微调过程增加了模型训练的复杂性。


5. 实际应用
剪枝技术已广泛应用于移动设备、边缘计算和实时应用中,例如图像识别、自然语言处理等任务,很多现代深度学习框架(如TensorFlow、PyTorch)都有包含剪枝的相关工具和库。

知识蒸馏(Knowledge Distillation)

是一种模型压缩技术,旨在将大型深度学习模型(通常称为“教师模型”)中的知识转移到较小的模型(称为“学生模型”)中。这种技术在计算资源有限的环境下尤为重要,因为它可以提高推理速度并减少模型的存储需求。
知识蒸馏主要包括以下几个步骤:

训练教师模型:首先,训练一个大型性能良好的教师模型。此模型通常在大规模数据集上经过充分训练,能够非常有效地捕捉数据中的复杂模式。

生成软标签:教师模型在训练集上预测的输出称为“软标签”。软标签包含了每个类别的概率分布,相比于传统的硬标签(one-hot编码),它保留了更多的信息。

训练学生模型:使用教师模型的预测(软标签)来训练更小的学生模型。在训练过程中,学生模型将学习教师模型的预测分布,而不仅仅是目标类别。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.nn import functional as F
from ultralytics import YOLO

# 假设教师模型和学生模型已经定义
teacher_model = YOLO('yolo11n.pt').to(device)  
teacher_model.eval()  # 设置为评估模式
student_model = StudentModel().to(device)

# 知识蒸馏损失
def distillation_loss(y_true, y_pred, teacher_output, T, alpha):
    """
    distillation loss = alpha * cross_entropy_loss(y_true, y_pred) + (1 - alpha) * KL_divergence(softmax(teacher_output / T), softmax(y_pred / T))
    """
    loss_ce = F.cross_entropy(y_pred, y_true)
    loss_kl = F.kl_div(F.log_softmax(y_pred / T, dim=1), F.softmax(teacher_output / T, dim=1), reduction='batchmean')
    return alpha * loss_ce + (1 - alpha) * loss_kl

# 训练学生模型
optimizer = optim.Adam(student_model.parameters(), lr=1e-4)

# train
for epoch in range(num_epochs):
    student_model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        # 使用教师模型对训练数据生成软标签
        teacher_output = teacher_model(images)
        student_output = student_model(images)
        loss = distillation_loss(labels, student_output, teacher_output, T=2, alpha=0.7)
        loss.backward()
        optimizer. Step()

分类:

离线蒸馏,大型教师模型蒸馏前在训练样本训练;教师模型以logits或中间特征的形式提取知识,将其在蒸馏过程中指导学生模型的训练。教师的结构是预定义的,很少关注教师模型的结构及其与学生模型的关系。例如上面的蒸馏,使用预训练的权重作为教师模型。

在线蒸馏:教师模型和学生模型同步更新,而整个知识蒸馏框架都是端到端可训练的。

自蒸馏:教师和学生模型使用相同的网络,这可以看作是在线蒸馏的一个特例。

蒸馏算法,有基于注意力蒸馏,基于图蒸馏,基于生成对抗网络GAN蒸馏,量化蒸馏等等。


http://www.kler.cn/a/506186.html

相关文章:

  • 工作中redis常用的5种场景
  • Redis集群部署详解:主从复制、Sentinel哨兵模式与Cluster集群的工作原理与配置
  • 归子莫的科技周刊#2:白天搬砖,夜里读诗
  • 45_Lua模块与包
  • windows远程桌面连接限定ip
  • 【GIS操作】使用ArcGIS Pro进行海图的地理配准(附:墨卡托投影对比解析)
  • 强化学习的数学原理(七-3)TD算法总结
  • PHP中的魔术函数
  • SpringMVC Idea 搭建 部署war
  • 【React】插槽渲染机制
  • openharmony应用开发快速入门
  • Go实现设计模式
  • Python语言的编程范式
  • C++(二十)
  • 在 Azure 100 学生订阅中新建 Ubuntu VPS 并部署 Mastodon 服务器
  • 使用 `npm install` 时遇到速度很慢的问题
  • .Net MVC中视图的View()的具体用法
  • 【JavaScript】比较运算符的运用、定义函数、if(){}...esle{} 语句
  • Java中的高效集合操作:Stream API实战指南
  • 【2024年华为OD机试】(B卷,100分)- 数据分类 (Java JS PythonC/C++)
  • 使用python+pytest+requests完成自动化接口测试(包括html报告的生成和日志记录以及层级的封装(包括调用Json文件))
  • 浅谈云计算14 | 云存储技术
  • Windows图形界面(GUI)-QT-C/C++ - QT 对话窗口
  • python flask简单实践
  • 谷歌浏览器的兼容性与性能优化策略
  • MySQL程序之:使用选项设置程序变量